In [None]:
def make_prior_dict(sho_trace, low_freqs, high_freqs):
    """
    Summarize the sampled hyperparameters of a GP noise model into a dictionary of priors
    The expected sho_trace should be the output of a PyMC3/Exoplanet model built with noise.build_sho_model()

    Recognized variable names in sho_trace (any subset may be present):
      - ['logw0', 'logSw4', 'logQ', 'logw0_x', 'logSw4_x', 'logQ_x', 'logS1', 'logS2', 'logQ1']
      - the '_x' variants are stored under the un-suffixed key (e.g. 'logw0_x' --> 'logw0'),
        so a trace may not contain both a name and its '_x' variant
      - 'logQ1' is stored under the key 'logQh'

    Parameters
    ----------
    sho_trace : PyMC3 multitrace
        trace output of a PyMC3/Exoplanet model built with noise.build_sho_model()
        (must expose .varnames and support item access by variable name)
    low_freqs : list
        fixed (regular, not angular) frequency corresponding to 'logSw4'; only the first entry is used
    high_freqs : list
        fixed (regular, not angular) frequencies corresponding to 'logS1' and 'logS2'

    Returns
    -------
    priors : dict
        Keys can be any combination of ['logw0', 'logSw4', 'logQ', 'logS1', 'logS2', 'logQh', 'f0', 'f1', 'f2', ...]
        Each entry is a tuple of (median, std, percentile_01, percentile_99)
        Fixed quantities (frequencies and filled-in defaults) are stored as (value, 0., 0., 0.)

    Raises
    ------
    ValueError
        if the trace holds both a variable and its '_x' variant, or if 'logw0' is absent
        and cannot be derived from 'f0' and 'logQ'
    """
    def summarize(samples):
        # (median, std, 1st percentile, 99th percentile) of a set of posterior samples
        return (np.median(samples), np.std(samples), np.percentile(samples, 1), np.percentile(samples, 99))

    names = sho_trace.varnames

    # a parameter and its '_x' twin would collide on the same output key
    for base in ('logw0', 'logSw4', 'logQ'):
        if (base in names) and (base + '_x' in names):
            raise ValueError('Expected only one of %s or %s_x' % (base, base))

    # (name in trace, key in output); order fixes the insertion order of the output dictionary
    trace_to_prior = [('logw0',    'logw0'),
                      ('logSw4',   'logSw4'),
                      ('logQ',     'logQ'),
                      ('logw0_x',  'logw0'),
                      ('logSw4_x', 'logSw4'),
                      ('logQ_x',   'logQ'),
                      ('logS1',    'logS1'),
                      ('logS2',    'logS2'),
                      ('logQ1',    'logQh')]

    priors = {}
    for trace_name, prior_name in trace_to_prior:
        if trace_name in names:
            priors[prior_name] = summarize(sho_trace[trace_name])

    # fixed frequencies carry no uncertainty
    if len(low_freqs) > 0:
        priors['f0'] = (low_freqs[0], 0., 0., 0.)

    for number, freq in enumerate(high_freqs, start=1):
        priors['f' + str(number)] = (freq, 0., 0., 0.)

    # fill in anything that is missing
    # NOTE: logw0 is resolved BEFORE the default logQ is filled in, so a sampled logQ is required to derive it
    if 'logw0' not in priors:
        if not (('f0' in priors) and ('logQ' in priors)):
            raise ValueError('Cannot determine logw0')

        # convert_frequency, pi, and T are expected to be provided by the enclosing module
        w0 = convert_frequency(2*pi*priors['f0'][0], T.exp(priors['logQ'][0]))
        priors['logw0'] = (np.log(float(w0.eval())), 0., 0., 0.)

    if 'logQ' not in priors:
        priors['logQ'] = (np.log(1/np.sqrt(2)), 0., 0., 0.)

    return priors

In [None]:
def cut_stamps(time, data, tts, dur, dtype, stampsize=1.5):
    """
    Cut out a stamp centered on each transit time from a full Kepler lightcurve

    Parameters
    ----------
    time : array-like
        time values at each cadence
    data : array-like
        corresponding data (e.g. flux, error, cadence number, etc.), same length as time
    tts : array-like
        list of transit times
    dur : float
        transit duration
    dtype : string
        ndarray datatype to assign to output stamps
    stampsize : float
        distance from each transit center to cut, in transit durations (default=1.5)

    Returns
    -------
    stamps : list
        list of ndarray stamps (one per transit time); a stamp is empty if no cadences fall near that transit
    """
    # coerce to ndarray so that plain lists/tuples support arithmetic and boolean-mask indexing
    time = np.asarray(time)
    data = np.asarray(data)

    stamps = []

    # cut out the stamps
    for t0 in tts:
        neartransit = np.abs(time - t0)/dur < stampsize
        stamps.append(np.array(data[neartransit], dtype=dtype))

    return stamps

In [None]:
def combine_stamps(sc_stamps, lc_stamps):
    '''
    Merge per-transit short cadence and long cadence stamps into a single list
    For each transit the short cadence stamp is used if it is non-empty, otherwise the long cadence stamp

    Parameters
    ----------
    sc_stamps : list
        list of short cadence stamps (one per transit)
    lc_stamps: list
        list of long cadence stamps (one per transit), same length as sc_stamps

    Returns
    -------
    stamps_out : list
        list of selected stamps; transits with no data in either cadence get an empty list
    stamp_cadence: ndarray
        array of len(stamps_out) specifying cadence of each stamp as 'short', 'long', or 'none'

    Raises
    ------
    ValueError
        if sc_stamps and lc_stamps have different lengths
    '''
    if len(sc_stamps) != len(lc_stamps):
        raise ValueError('inconsistent number of stamps')

    stamps_out = []
    labels = []

    # walk the two lists in lockstep, preferring short cadence data
    for sc, lc in zip(sc_stamps, lc_stamps):
        if len(sc) > 0:
            chosen, label = sc, 'short'
        elif len(lc) > 0:
            chosen, label = lc, 'long'
        else:
            chosen, label = [], 'none'

        stamps_out.append(chosen)
        labels.append(label)

    return stamps_out, np.array(labels)

In [None]:
def ttv_lombscargle_analysis(tts, ephemeris, period):
    """
    Generates a Lomb-Scargle periodogram from a series of measured transit times and a linear ephemeris

    Parameters
    ----------
    tts : array-like
        vector of transit times
    ephemeris : array-like
        vector of times corresponding to the linear ephemeris, same length as tts
    period: float
        orbital period

    Returns
    -------
    LS : dict
        {freq, power, peak_freq, peak_fap, 2nd_peak}
        'freq' and 'power' are the periodogram arrays, 'peak_freq' is the frequency of maximum power,
        'peak_fap' is its bootstrap false alarm probability, and '2nd_peak' is the frequency of
        maximum power outside a window of +/- 11 frequency steps around 'peak_freq'
    """
    # import locally so the function does not rely on `import astropy` having loaded these submodules;
    # LombScargle lives in astropy.timeseries in current astropy, with astropy.stats as the legacy location
    from astropy.stats import mad_std
    try:
        from astropy.timeseries import LombScargle
    except ImportError:
        from astropy.stats import LombScargle

    tts = np.asarray(tts, dtype='float')
    ephemeris = np.asarray(ephemeris, dtype='float')

    # observed minus calculated
    omc = tts - ephemeris

    # Hann window to reduce ringing
    hann = sig.windows.hann(len(omc))
    hann /= np.sum(hann)

    # identify any egregious outliers (deviation from a running median, scaled by the robust scatter of omc)
    local_trend = sig.medfilt(omc, kernel_size=7)
    out = np.abs(omc - local_trend)/mad_std(omc) > 5.0

    # compute a Lomb-Scargle periodogram; find 1st and 2nd peaks
    lombscargle = LombScargle(tts[~out], omc[~out]*hann[~out])
    freq, power = lombscargle.autopower(minimum_frequency=2.0/(tts.max()-tts.min()),
                                        maximum_frequency=0.25/period,
                                        samples_per_peak=11)

    peak_freq = freq[np.argmax(power)]
    peak_fap  = lombscargle.false_alarm_probability(power.max(), method='bootstrap')

    # exclude the neighborhood of the primary peak before searching for the secondary peak
    mask = np.abs(freq - peak_freq) < 11*np.median(freq[1:]-freq[:-1])
    second_peak = freq[~mask][np.argmax(power[~mask])]

    ls = {}
    ls['freq'] = freq
    ls['power'] = power
    ls['peak_freq'] = peak_freq
    ls['peak_fap'] = peak_fap
    ls['2nd_peak'] = second_peak

    return ls

In [None]:
import numpy as np
import scipy.optimize as op
import scipy.signal as sig
from   scipy import stats
import astropy
import sys
import os
import warnings

from .constants import *

__all__ = ['Planet']

class Planet:
    '''
    Container for the transit parameters and per-transit data stamps of a single planet

    All per-transit quantities (tts, quality, stamp_cadence, the *_stamps lists, icov, ...) are
    expected to be index-aligned, i.e. entry i of each refers to the same transit
    '''
    def __init__(self, epoch=None, period=None, depth=None, duration=None, index=None, tts=None, tts_err=None, quality=None, \
                 pttv=None, time_stamps=None, flux_stamps=None, error_stamps=None, mask_stamps=None, model_stamps=None, \
                 stamp_cadence=None, stamp_coverage=None, stamp_chisq=None, icov=None, cadno_stamps=None):


        self.epoch            = epoch            # reference transit time in range (0, period)
        self.period           = period           # orbital period
        self.depth            = depth            # transit depth
        self.duration         = duration         # transit duration

        self.index            = index            # index of each transit in range (0,1600) -- Kepler baseline
        self.tts              = tts              # all midtransit times in range (0,1600) -- Kepler baseline
        self.tts_err          = tts_err          # corresponding 1-sigma error bars on transit times
        self.quality          = quality          # boolean flag per transit; True=good

        self.pttv             = pttv             # [Amp, Pttv, phi, C0, C1, C2, C3]

        self.time_stamps      = time_stamps      # list of time stamps (one per transit) centered on midtransit
        self.flux_stamps      = flux_stamps      # list of flux stamps
        self.error_stamps     = error_stamps     # list of error stamps
        self.mask_stamps      = mask_stamps      # list of mask stamps, (mask=1 where OTHER planets transit)
        self.model_stamps     = model_stamps     # list of model stamps
        self.cadno_stamps     = cadno_stamps     # list of cadence number stamps (optional; may be None)

        self.stamp_cadence    = stamp_cadence    # 'short', 'long', or 'none'
        self.stamp_coverage   = stamp_coverage   # fraction of transit/baseline covered by useable cadences
        self.stamp_chisq      = stamp_chisq      # chi-sq per transit

        self.icov             = icov             # inverse covariance matrix


    def _trim_stamps(self, i, keep):
        '''
        Keep only the cadences flagged True in `keep` for the stamps of transit i

        -- updates time_, flux_, and error_stamps, plus mask_ and cadno_stamps when those are set
        '''
        self.time_stamps[i]  = self.time_stamps[i][keep]
        self.flux_stamps[i]  = self.flux_stamps[i][keep]
        self.error_stamps[i] = self.error_stamps[i][keep]

        # mask and cadence-number stamps are optional, so only trim them if they were provided
        if self.mask_stamps is not None:
            self.mask_stamps[i] = self.mask_stamps[i][keep]
        if self.cadno_stamps is not None:
            self.cadno_stamps[i] = self.cadno_stamps[i][keep]

        return None


    def _select_transits(self, cadence):
        '''
        Build the per-transit selection used by grab_stamps() and grab_icov()

        cadence: 'short', 'long', or 'any'

        Returns the product of self.quality with the requested cadence match ('any' = everything except 'none')
        '''
        if cadence == 'any':
            matches = ~(self.stamp_cadence == 'none')
        elif cadence == 'short' or cadence == 'long':
            matches = (self.stamp_cadence == cadence)
        else:
            raise ValueError('cadence must be "short", "long", or "any"')

        return self.quality * matches


    def mask_overlapping_transits(self):
        '''
        Remove cadences from stamps where other planets transit

        -- automatically updates time_, flux_, error_, mask_, and (if set) cadno_stamps
        '''
        for i, m in enumerate(self.mask_stamps):
            if len(m) > 0:
                # cast to bool so that ~ is a logical (not bitwise integer) negation
                self._trim_stamps(i, ~np.asarray(m, dtype='bool'))

        return None


    def clip_outlier_cadences(self, sigma=5.0, kernel_size=7):
        '''
        Do some iterative sigma rejection on each stamp

        sigma: rejection threshold for clipping (default=5.0)
        kernel_size: size of window for median filter (default=7)

        -- automatically updates time_, flux_, error_, mask_, and (if set) cadno_stamps
        '''
        for i, f in enumerate(self.flux_stamps):
            # repeat until a pass finds no outliers (or the stamp has been emptied)
            while len(self.flux_stamps[i]) > 0:
                flux = self.flux_stamps[i]

                smoothed = sig.medfilt(flux, kernel_size=kernel_size)
                outliers = np.abs(flux-smoothed)/self.error_stamps[i] > sigma

                if np.sum(outliers) == 0:
                    break

                self._trim_stamps(i, ~outliers)

        return None


    def flatten_stamps(self, jitter=0.1):
        '''
        Fit a linear polynomial to out-of-transit flux to flatten data flux stamps

        jitter: fudge factor to avoid fitting in-transit flux if there are unresolved TTVs (default=0.1)

        -- automatically updates flux_stamps on Planet object
        '''
        for i, flux in enumerate(self.flux_stamps):
            if len(flux) > 0:
                time = self.time_stamps[i]

                intransit = np.abs(time-self.tts[i])/self.duration < 0.5+jitter

                # a straight line needs at least two points; otherwise leave the stamp unscaled
                if np.sum(~intransit) > 1:
                    coeffs = np.polyfit(time[~intransit],flux[~intransit],1)
                    linfit = np.polyval(coeffs, time)
                else:
                    linfit = 1.0

                self.flux_stamps[i] = flux/linfit

        return None



    def calculate_stamp_coverage(self, stampsize=1.5):
        '''
        Calculate the fraction of expected cadences present in each stamp and store it as stamp_coverage

        stampsize: distance from each transit center to consider, in transit durations (default=1.5)

        -- coverage is the smaller of the overall fraction and the in-transit fraction
        -- stamps with cadence 'none' get a coverage of zero
        '''
        # determine locations of SC and LC data
        sc_loc = self.stamp_cadence == 'short'
        lc_loc = self.stamp_cadence == 'long'

        # expected number of points in stamp if none are missing
        # (SCIT and LCIT come from the constants module; the divisors convert them to days)
        expected_sc_pts = 2*stampsize*self.duration/(SCIT/3600/24)
        expected_lc_pts = 2*stampsize*self.duration/(LCIT/60/24)

        # count up points per stamp overall
        pts_overall = np.array([len(t) for t in self.time_stamps])

        # count up points per stamp in transit
        pts_in_transit = []
        for i, t0 in enumerate(self.tts):
            # asarray lets empty list placeholders (cadence 'none') pass through the arithmetic
            t = np.asarray(self.time_stamps[i])
            pts_in_transit.append(np.sum(np.abs(t-t0) < self.duration/2))
        pts_in_transit = np.array(pts_in_transit)

        # calculate cover fraction (force float so fractions are never truncated by an integer tts dtype)
        overall_fraction = np.zeros_like(self.tts, dtype='float')
        overall_fraction[sc_loc] = pts_overall[sc_loc]/expected_sc_pts
        overall_fraction[lc_loc] = pts_overall[lc_loc]/expected_lc_pts

        in_transit_fraction = np.zeros_like(self.tts, dtype='float')
        in_transit_fraction[sc_loc] = pts_in_transit[sc_loc]/(expected_sc_pts/2/stampsize)
        in_transit_fraction[lc_loc] = pts_in_transit[lc_loc]/(expected_lc_pts/2/stampsize)

        # use the smaller value as the coverage
        self.stamp_coverage = np.minimum(overall_fraction, in_transit_fraction)

        return None


    def calculate_stamp_chisq(self):
        '''
        Compare model_stamps, flux_stamps, and icov to calculate chisq for each transit

        -- transits flagged as bad in self.quality are assigned chisq = inf
        '''
        mstamps = self.grab_stamps('model')
        fstamps = self.grab_stamps('flux')
        icov    = self.grab_icov()

        stamp_chisq = []

        # the grabbed lists hold entries for good transits only, so they are indexed by a separate counter j
        j = 0
        for i, good in enumerate(self.quality):
            if good:
                y = mstamps[j]-fstamps[j]
                stamp_chisq.append(np.dot(y.T,np.dot(icov[j],y)))
                j += 1
            else:
                stamp_chisq.append(np.inf)

        self.stamp_chisq = np.array(stamp_chisq)

        return None


    def identify_good_transits(self, cover_fraction=0.7, chisq_sigma=5.0, verbose=True):
        '''
        Identify transits with sufficient coverage and non-outlier chisq

        cover_fraction: coverage threshold; eg. 0.7 keeps only stamps with coverage above 70% (default=0.7)
        chisq_sigma: sigma threshold to reject stamps as poorly fit (default=5.0)
        verbose: boolean flag; 'True' to print results

        -- recalculates stamp_coverage and overwrites self.quality
        '''
        # flag stamps with sufficient coverage
        self.calculate_stamp_coverage()
        enough_pts = self.stamp_coverage > cover_fraction
        nonempty   = self.stamp_coverage > 0

        # count up points per stamp
        pts_per_stamp = np.array([len(t) for t in self.time_stamps])

        # flag stamps with unusually high chisq values (use pseudo-reduced-chisq)
        reject_chisq = np.zeros_like(self.tts, dtype='bool')

        if self.stamp_chisq is not None:
            finite = ~np.isinf(self.stamp_chisq)

            X2u = self.stamp_chisq[finite] / (pts_per_stamp[finite])
            mad = astropy.stats.mad_std(X2u)
            med = np.median(X2u)
            reject_chisq[finite]  = np.abs(X2u-med)/mad > chisq_sigma
            reject_chisq[~finite] = True

        # print out results
        if verbose:
            print('%d out of %d transits rejected with high chisq' \
                  %(np.sum(reject_chisq[enough_pts]), np.sum(nonempty)))
            print('%d out of %d transits rejected with insufficient coverage' \
                  %(np.sum(~enough_pts[nonempty]), np.sum(nonempty)))

        # save the results
        self.quality = enough_pts * ~reject_chisq

        return None


    def grab_stamps(self, stamptype, cadence='any'):
        '''
        Return the stamps of the requested type for all good-quality transits with matching cadence

        stamptype: 'time', 'flux', 'error', 'mask', 'model', or 'cadno'
        cadence: 'short', 'long', or 'any'

        -- raises ValueError for an unrecognized stamptype or cadence
        '''
        stamp_attributes = {'time':  'time_stamps',
                            'flux':  'flux_stamps',
                            'error': 'error_stamps',
                            'mask':  'mask_stamps',
                            'model': 'model_stamps',
                            'cadno': 'cadno_stamps'}

        if stamptype not in stamp_attributes:
            raise ValueError('stamptype must be "time", "flux", "error", "mask", "model", or "cadno"')

        stamps = getattr(self, stamp_attributes[stamptype])
        use = self._select_transits(cadence)

        return [s for i, s in enumerate(stamps) if use[i]]



    def grab_icov(self, cadence='any'):
        '''
        Return the inverse covariance matrices for all good-quality transits with matching cadence

        cadence: 'short', 'long', or 'any'

        -- raises ValueError for an unrecognized cadence
        '''
        use = self._select_transits(cadence)

        return [c for i, c in enumerate(self.icov) if use[i]]

In [None]:
def notch_filter(data, f0, fsamp, Q):
    """
    Remove a narrow band of frequencies from a data array using a 2nd-order notch filter
    The filter is designed with scipy.signal.iirnotch and applied with scipy.signal.filtfilt

    Parameters
    ----------
    data : array-like
        data to be filtered
    f0 : float
        center frequency of stopband
    fsamp: float
        sampling frequency, same units as f0
    Q : float
        quality factor

    Returns
    -------
    data_filtered: array-like
        data array with selected frequency filtered out
    """
    # iirnotch expects the notch frequency as a fraction of the Nyquist frequency
    nyquist = fsamp/2
    numerator, denominator = sig.iirnotch(f0/nyquist, Q)

    return sig.filtfilt(numerator, denominator, data)



def FFT_estimator(x, y, sigma=5.0):
    """
    Identify significant frequencies in a (uniformly sampled) data series
    Fits a Lorentzian around each peak

    Parameters
    ----------
    x : ndarray
        1D array of x data values; should be monotonically increasing and uniformly spaced
    y : array-like
        1D array of corresponding y data values, len(x)
    sigma : float
        sigma threshold for selecting significant frequencies (default=5.0)

    Returns
    -------
    xf : ndarray
        1D array of frequency values (restricted to frequencies above the low-frequency cutoff)
    yf : ndarray
        1D array of (median-subtracted, smoothed) response values, len(xf)
    freqs : ndarray
        array of significant frequencies, taken from the fitted Lorentzian centers
    """
    npts = len(x)
    N = npts//2

    # sampling interval and longest testable period (conservative low-freq cutoff)
    dt = x[1]-x[0]
    Tmax = (x.max()-x.min())/4

    # FFT of the data multiplied by a hann window (to reduce spectral leakage)
    window = sig.windows.hann(npts)

    # DFT bin k of an npts-long series corresponds to frequency k/(npts*dt); keep the non-negative half
    xf = np.fft.fftfreq(npts, d=dt)[:N]
    yf = np.abs(np.fft.fft(window*y)[:N])

    yf -= np.median(yf)

    keep = xf > 1/Tmax

    xf = xf[keep]
    yf = yf[keep]

    # boxcar_smooth and lorentzian are helper functions expected to be provided by the enclosing module
    yf = boxcar_smooth(yf, 3, 1)

    # make a copy of raw xf and yf data
    xf_all = xf.copy()
    yf_all = yf.copy()

    def residuals(theta, xdata, ydata):
        # difference between the spectrum and a Lorentzian with parameters theta
        return ydata - lorentzian(theta, xdata)

    # now search for significant frequencies
    freqs = []

    # each pass masks at least the current peak bin, so len(xf) passes is a hard upper bound
    for _ in range(len(xf)):
        yf_noise = astropy.stats.mad_std(yf)

        ipeak = np.argmax(yf)
        peakfreq = xf[ipeak]
        peakval = yf[ipeak]

        if not ((peakval/yf_noise > sigma) and (peakval > 1/peakfreq)):
            break

        theta_in = np.array([peakfreq, 1/Tmax, peakval, np.median(yf)])
        theta_out, success = op.leastsq(residuals, theta_in, args=(xf, yf))

        # NOTE(review): the original expression was np.max(5*[width, 3*df]); multiplying a list by 5 only
        # repeats it, so the factor had no effect. Behavior is preserved here -- confirm whether a 5x
        # scaling of the width was actually intended.
        width = np.max([theta_out[1], 3*(xf[1]-xf[0])])
        mask = np.abs(xf-theta_out[0])/width < 1

        # always remove the peak bin itself so a poorly converged fit cannot re-select the same peak forever
        mask[ipeak] = True

        # replace the peak region with the fitted baseline level
        yf[mask] = theta_out[3]

        freqs.append(theta_out[0])

    freqs = np.array(freqs)

    return xf_all, yf_all, freqs