In [None]:
"""
UTF-8, Python 3

------------------
Flaring SPI
------------------

Ekaterina Ilin, 2021, MIT License

De-trending Kepler and TESS light curves and searching them for flares:

- get table of Kepler and TESS light curves for each exoplanet host system
- fetch the FlareLightCurve (FLC) from MAST
- get system info from table
- mask transits
- apply custom detrending
- search flares
- save results
"""


import copy
import time

from funcs.notebook import *
from funcs.detrend import estimate_detrended_noise, fit_spline, remove_sines_iteratively, remove_exponential_fringes
from funcs.transitmask import get_full_transit_mask

from altaipony.lcio import from_mast, from_path
from altaipony.flarelc import FlareLightCurve
from altaipony.altai import find_iterative_median

from lightkurve import search_lightcurvefile

from astropy.io import fits


def _odd_window_length(hours, dt):
    """Convert a window length in hours into an odd number of cadences of length dt (days)."""
    return int((np.rint(hours / 24. / dt) // 2) * 2 + 1)


def custom_detrending(lc, spline_coarseness=30, spline_order=3,
                      savgol1=6., savgol2=3., pad=3):
    """Custom de-trending for TESS and Kepler
    short cadence light curves, including TESS Cycle 3 20s
    cadence.

    The de-trending runs in stages: a coarse spline fit, iterative
    removal of sines, two Savitzky-Golay passes with windows of
    ``savgol1`` and ``savgol2`` hours, an iterative median, and finally
    removal of exponential fringes.

    Parameters:
    ------------
    lc : FlareLightCurve
        light curve that has at least time, flux and flux_err
    spline_coarseness : float
        time scale in hours for spline points.
        See fit_spline for details.
    spline_order: int
        Spline order for the coarse spline fit.
        Default is cubic spline.
    savgol1 : float
        Window size for first Savitzky-Golay filter application.
        Unit is hours, defaults to 6 hours.
    savgol2 : float
        Window size for second Savitzky-Golay filter application.
        Unit is hours, defaults to 3 hours.
    pad : int
        Outliers in Savitzky-Golay filter are padded with this
        number of data points. Defaults to 3.

    Return:
    -------
    FlareLightCurve with detrended_flux attribute
    """
    # Typical cadence in days. Use the median time step rather than the
    # mean: data gaps inflate the mean, which would make the Savitzky-Golay
    # windows below shorter than the requested number of hours.
    dt = np.nanmedian(np.diff(lc.time.value))

    # fit a spline to the general trends
    lc1, model = fit_spline(lc, spline_order=spline_order,
                            spline_coarseness=spline_coarseness)

    # replace for next step
    lc1.flux = lc1.detrended_flux.value

    # remove strong and fast variability (original note: on 5 day to
    # 4.8 hour time scales). Simple sines are probably sufficient because
    # rotational variability is either weak and transient or strong and
    # persistent on these time scales.
    lc2 = remove_sines_iteratively(lc1)

    # first Savitzky-Golay pass with a window of savgol1 hours
    # (odd number of cadences) to iron out the rest
    w = _odd_window_length(savgol1, dt)
    lc3 = lc2.detrend("savgol", window_length=w, pad=pad)

    # second, shorter Savitzky-Golay pass with a window of savgol2 hours
    w = _odd_window_length(savgol2, dt)
    lc4 = lc3.detrend("savgol", window_length=w, pad=pad)

    # find median value
    lc4 = find_iterative_median(lc4)

    # replace for next step
    lc4.flux = lc4.detrended_flux.value

    # remove exponential fringes that neither spline,
    # nor sines, nor SavGol can remove.
    lc5 = remove_exponential_fringes(lc4)

    return lc5


def add_meta_data_and_write(ff, dflcn, ID, TIC, sector, mission,
                            lc_n, w, tstamp, mask_pos_outliers_sigma,
                            path="../results/2022_07_flares.csv"):
    """Add meta data to a flare table and append it to the results CSV.

    Parameters
    ----------
    ff : pandas.DataFrame
        Flare table. The meta data columns are added to it in place.
    dflcn : FlareLightCurve
        De-trended light curve; the length of its ``detrended_flux`` is
        stored as ``total_n_valid_data_points``.
    ID, TIC, sector, mission, lc_n, w, tstamp, mask_pos_outliers_sigma
        Values written into the columns ``ID``, ``TIC``, ``qcs``,
        ``mission``, ``lc_n``, ``w``, ``tstamp`` and
        ``mask_pos_outliers_sigma``.
    path : str
        CSV file the table is appended to (no header, no index).

    Notes
    -----
    If ``ff`` is empty, a single placeholder row with ``phase=-1`` and
    ``real=-1`` is written instead. This way, light curves without flares
    are still recorded in the results file.
    """
    meta = {
        "total_n_valid_data_points": dflcn.detrended_flux.shape[0],
        "ID": ID,
        "TIC": TIC,
        "qcs": sector,
        "mission": mission,
        "tstamp": tstamp,
        "lc_n": lc_n,
        "w": w,
        "mask_pos_outliers_sigma": mask_pos_outliers_sigma,
    }

    if ff.shape[0] == 0:
        # Create the columns on the empty table first so the placeholder
        # row comes out with the same column order as a regular table.
        ff["phase"] = -1
        for column, value in meta.items():
            ff[column] = value
        ff["real"] = -1
        placeholder = {"phase": -1, **meta, "real": -1}
        # DataFrame.append was removed in pandas 2.0; pd.concat replaces it.
        ff = pd.concat([ff, pd.DataFrame([placeholder])], ignore_index=True)

    # otherwise add ID, QCS, mission and the other meta data
    else:
        for column, value in meta.items():
            ff[column] = value

    # add results to file
    with open(path, "a") as file:
        ff.to_csv(file, index=False, header=False)
            
def write_flc_to_file(dflcn, flc, path_dflcn):
    """Save a de-trended light curve to a FITS file at ``path_dflcn``.

    Several columns are stored as extra columns next to what ``to_fits``
    writes by default:
    - the raw flux of ``flc``;
    - the de-trended flux and its uncertainty;
    - the iterative median;
    - the flux model;
    - the orbital phase.

    An existing file is overwritten.
    """
    extra_columns = dict(
        FLUX=flc.flux.value,
        DETRENDED_FLUX=dflcn.detrended_flux.value,
        DETRENDED_FLUX_ERR=dflcn.detrended_flux_err.value,
        IT_MED=dflcn.it_med.value,
        FLUX_MODEL=dflcn.flux_model.value,
        PHASE=dflcn.phase,
    )
    dflcn.to_fits(path_dflcn, **extra_columns, overwrite=True)

def write_no_lc(input_target, path="../results/2022_07_nolc.txt"):
    """Record a target for which no suitable light curve was found.

    Parameters
    ----------
    input_target : row of the input catalog
        Only its ``TIC`` attribute is read.
    path : str
        Text file the line ``TIC <TIC ID>`` is appended to.
    """
    with open(path, "a") as f:
        f.write(f"TIC {input_target.TIC}\n")
    
    
sep = "-----------------------------------------"


def mprint(message):
    """Print ``message`` framed by a separator line above and below."""
    framed = "\n".join([sep, str(message), sep])
    print(framed)


# Offsets subtracted from BJD timestamps (e.g. transit midtimes) to convert
# them into the time system of the respective mission.
offset = {
    "K2": 2454833.,
    "Kepler": 2454833.,
    "TESS": 2457000.,
    'Transiting Exoplanet Survey Satellite (TESS)': 2457000.,
}

In [None]:
def _transit_midtime(input_target, mission):
    """Return the transit midtime of ``input_target`` in mission time (BJD minus offset)."""
    if mission == "TESS":
        # prefer the TESS-specific ephemeris when it is available
        if np.isfinite(input_target.pl_tranmid_tess):
            return input_target.pl_tranmid_tess - offset[mission]
        return input_target.pl_tranmid - offset[mission]
    if mission == "Kepler":
        return input_target.pl_tranmid - offset[mission]
    raise ValueError(f"Unsupported mission: {mission!r}")


def _detrended_lc_path(download_dir, tstamp, input_target, sector, mission, i):
    """Build the FITS output path for a de-trended light curve."""
    if mission == "TESS":
        target = input_target.TIC
    elif mission == "Kepler":
        # Spaces in host names must not end up in file names. Hyphens are
        # kept so that names match light curves written earlier.
        target = str(input_target.hostname).replace(" ", "_")
    else:
        raise ValueError(f"Unsupported mission: {mission!r}")
    return f"{download_dir}/{tstamp}_{target}_{sector}_altai_{i}.fits"


def run_analysis(flc, input_target, sector, mission, lc_n, download_dir,
                 i=0, mask_pos_outliers_sigma = 2.5, addtail = True):
    """De-trend a light curve, search it for flares, and save the results.

    The flare table is appended to the results CSV (see
    ``add_meta_data_and_write``). The de-trended light curve is written
    to a FITS file in ``download_dir``.

    Parameters
    ----------
    flc : FlareLightCurve
        Light curve as fetched from MAST.
    input_target : pandas.Series
        Row of the input catalog. The attributes hostname, TIC, pl_orbper
        and pl_tranmid are read, plus pl_tranmid_tess for TESS.
    sector : str
        TESS sector or Kepler quarter.
    mission : str
        "TESS" or "Kepler".
    lc_n : int
        Running number of this light curve for the target.
    download_dir : str
        Directory the de-trended FITS file is written to.
    i : int
        Index of the light curve if several were found for one quarter.
    mask_pos_outliers_sigma : float
        Passed on to estimate_detrended_noise.
    addtail : bool
        Passed on to find_flares.

    Returns
    -------
    int
        Number of flares found in the light curve.

    Raises
    ------
    ValueError
        If ``mission`` is neither "TESS" nor "Kepler". This is checked
        before the (expensive) de-trending starts.
    """
    # resolve mission-dependent quantities first so that an unsupported
    # mission fails fast instead of after the de-trending
    midtime = _transit_midtime(input_target, mission)

    # get timestamp for result
    tstamp = time.strftime("%Y_%m_%d", time.localtime())
    print(f"date: {tstamp}")

    path_dflcn = _detrended_lc_path(download_dir, tstamp, input_target,
                                    sector, mission, i)

    dflc = custom_detrending(flc)
    print("LC successfully detrended.")

    # two hour (1/12 day) window for the rolling std, forced to be odd
    w = np.floor(1. / 12. / np.nanmin(np.diff(dflc.time.value)))
    if w % 2 == 0:
        w += 1

    # use window to estimate the noise in the LC
    dflcn = estimate_detrended_noise(dflc, std_window=int(w),
                                     mask_pos_outliers_sigma=mask_pos_outliers_sigma)

    # search the residual for flares
    ff = dflcn.find_flares(addtail=addtail).flares

    print(f"Transit midtime in {mission} time: {midtime}")

    # orbital phase of every cadence, 0 at transit midtime
    orbper = input_target.pl_orbper
    dflcn['phase'] = ((dflcn.time.value - midtime) % orbper) / orbper

    # phase at which each flare started
    ff["phase"] = ff.cstart.apply(lambda x: dflcn["phase"][np.where(x == dflcn.cadenceno)][0])

    # this is just to get the order of columns right, will be added later again
    if ff.shape[0] > 0:
        del ff["total_n_valid_data_points"]

    # chop out all phases where we have no data points to look for flares in:
    dflcn["phase"][~np.isfinite(dflcn["detrended_flux"])] = np.nan

    fshow = ff[["tstart", 'tstop', "phase", "ampl_rec", "dur"]]
    if fshow.shape[0] > 0:
        print(f"Flares found:\n{fshow}")
    else:
        print('No flares found in LC.')

    # add meta info to flare table
    # if no flares found, add empty row and write to file
    add_meta_data_and_write(ff, dflcn, input_target.hostname,
                            input_target.TIC, sector,
                            mission, lc_n, w, tstamp,
                            mask_pos_outliers_sigma)

    # write out detrended light curve
    write_flc_to_file(dflcn, flc, path_dflcn)
    print(f"Wrote out LC to {path_dflcn}.")

    return ff.shape[0]

def get_table_of_light_curves(input_target):
    """Find short-cadence Kepler and TESS light curves of a target on MAST.

    The target is first searched for by host name. If that search raises
    a KeyError, it is searched for by its TIC ID instead. Only light
    curves with an exposure time below 130 s that are not TASOC products
    are kept. Of these, one light curve per TESS sector / Kepler quarter
    is chosen, the first after sorting by exposure time.

    Returns
    -------
    pandas.DataFrame or None
        Kepler rows followed by TESS rows. Returns None if nothing suitable
        was found; the target is then logged with write_no_lc.
    """
    # lookups are evaluated lazily, inside the try, so the TIC query is
    # only built if the host name search failed
    lookups = (lambda: input_target.hostname,
               lambda: f"TIC {input_target.TIC}")

    selection = None
    for lookup in lookups:
        try:
            results = search_lightcurvefile(lookup())
            usable = (results.exptime.value < 130) & (results.author != "TASOC")
            selection = results[usable]
        except KeyError:
            continue
        break

    if selection is None or len(selection) == 0:
        write_no_lc(input_target)
        return

    table = selection.table.to_pandas().sort_values(by="t_exptime")
    is_tess = table.mission.str[:4] == "TESS"
    is_kepler = table.mission.str[:6] == "Kepler"
    tess = table.loc[is_tess, :].drop_duplicates(subset=["mission"], keep="first")
    kepler = table[is_kepler].drop_duplicates(subset=["mission"])

    return pd.concat([kepler, tess])

In [None]:
# Input catalog compiled from the Composite Table of confirmed exoplanets.
# The analysis below reads hostname, TIC, pl_orbper, pl_tranmid and
# pl_tranmid_tess from its rows.
path = "../data/2022_07_27_input_catalog_star_planet_systems.csv"

mprint(f"[UP] Using compiled input catalog from {path}")

input_catalog = pd.read_csv(path) 

In [None]:
# Scratch check: a freshly created, empty DataFrame reports zero rows.
lcs_sel = pd.DataFrame()
len(lcs_sel.index)

In [None]:
input_catalog.loc[input_catalog["hostname"] == "Kepler-1313"]


In [None]:
# Row index in input_catalog of the next target to process. The main loop
# advances it, so re-running this cell restarts from the top of the catalog.
count = 0
# NOTE(review): leftover from a resumed run: "Kepler-24 is the next one in line!"

In [None]:
# Main loop: walk through input_catalog starting at row `count`. For each
# target, de-trend every Kepler/TESS light curve and search it for flares.
# Stop once MAX_FLARES flares have been found in total, or when the catalog
# is exhausted. `count` is left pointing at the next unprocessed target,
# so re-running this cell resumes where the previous run stopped.
MAX_FLARES = 1000
download_dir = "/home/ekaterina/Documents/001_science/lcs"
nothing_found_path = "../results/2022_07_listed_but_nothing_found.txt"

Nflares = 0        # flares found for the current target
Nflares_total = 0  # flares found in this run, used as stopping criterion
while Nflares_total < MAX_FLARES and count < len(input_catalog):
    print(f"\nCOUNT: {count}\n")

    # advance through the catalog until a target with usable LCs turns up
    lcs_sel = pd.DataFrame()
    while lcs_sel.shape[0] == 0 and count < len(input_catalog):
        input_target = input_catalog.iloc[count]
        lcs_sel = get_table_of_light_curves(input_target)
        if lcs_sel is None:
            lcs_sel = pd.DataFrame()
        count += 1
    if lcs_sel.shape[0] == 0:
        print("Reached the end of the input catalog.")
        break

    TIC = "TIC " + str(input_target.TIC)
    ID = input_target.hostname
    Nflares = 0
    for n in range(lcs_sel.shape[0]):
        row = lcs_sel.iloc[n]
        # mission strings end in the sector/quarter number; taking the last
        # word also works for unpadded or three-digit numbers
        sector = row.mission.split(" ")[-1]
        mission = row.mission.split(" ")[0]
        lc_n = n + 1
        cadence = "fast" if row.exptime < 30 else "short"

        print(f"Get {mission} Sector/Quarter {sector}, {TIC}, {ID}, {cadence} cadence.")

        # fetch light curve(s) from MAST
        if mission == "TESS":
            flcl = from_mast(TIC, mission=mission, c=sector,
                             cadence=cadence, author="SPOC",
                             download_dir=download_dir)
        elif mission == "Kepler":
            flcl = from_mast(ID, mission=mission, c=sector,
                             cadence=cadence,
                             download_dir=download_dir)
        else:
            print(f"Skipping unsupported mission {mission}.")
            continue

        if flcl is None:
            print(f"No LC found for {mission}, {ID}, Quarter {sector}.")
            with open(nothing_found_path, "a") as f:
                f.write(f"{mission},{ID},{TIC},{sector},{cadence}\n")
            continue

        # from_mast may return a single light curve or a list of them
        flcs = flcl if isinstance(flcl, list) else [flcl]
        print(f"{len(flcs)} LC(s) found for {mission}, {ID}, Quarter {sector}.")
        for i, flc in enumerate(flcs):
            Nflares += run_analysis(flc, input_target, sector, mission,
                                    lc_n, download_dir, i=i)

    Nflares_total += Nflares
    print(f"\n---------------------\n{Nflares} flares found!\n-------------------\n")
print(f"\nNext count is {count}.\n")

In [None]:
# Scratch check: True unless the last fetch returned exactly a list.
type(flcl) is not list

In [None]:
# Inspect the flares found in one de-trended light curve.
# Other candidate files (kept for reference):
# 2022_07_29_Kepler-974_14_altai_2
# 2022_07_29_Kepler-974_15_altai_2
# 2022_07_29_Kepler-974_16_altai_1.fits +2
inspect_host, inspect_qcs, inspect_i = "Kepler-350", 14, 2
inspect_path = ("/home/ekaterina/Documents/001_science/lcs/"
                f"2022_07_29_{inspect_host}_{inspect_qcs}_altai_{inspect_i}.fits")

# close the file after reading; copy the table so it stays usable afterwards
with fits.open(inspect_path) as hdul:
    fff = hdul[1].data.copy()

# select the flares of the same host and quarter as the light curve file
# (previously filtered on the leftover global `ID` from the main loop)
ff = pd.read_csv("../results/2022_07_flares.csv")
ff = ff[(ff.ID == inspect_host) & (ff.qcs == inspect_qcs)]
ff

In [None]:
%matplotlib inline
for j, flare in ff.iloc[2:].iterrows():
    plt.figure(figsize=(16,5))
    cap=.5
    ts, tf = flare.tstart, flare.tstop
    print(ts,tf)
    _ = fff[np.where((fff['time']>=ts-.1/cap) & (fff['time']<=tf+.1/cap))]
    med = np.median(_['flux'])
    plt.plot(_['time'], _['flux']/med, c="k")
    
    _ = fff[np.where((fff['time']>=ts-1e-8) & (fff['time']<=tf+1e-8))]
    plt.scatter(_['time'], _['flux']/med, c="r")

    plt.scatter(ts, flare.phase/5 + 1)
    plt.plot(_['time'], _['phase']/5 + 1, c="grey")
#     plt.ylim(0.99,1.004)