In [101]:
import os
import pickle
import random

import numpy as np
import pandas as pd
import scipy.interpolate as interpolate
from scipy.stats import bootstrap

from astroquery.gaia import Gaia
Gaia.MAIN_GAIA_TABLE = "gaiadr3.gaia_source"
import astropy.units as un


from gaia_dr3_photometric_uncertainties import EDR3_Photometric_Uncertainties


# Photometric-uncertainty model built from the spline table shipped with the
# gaia_dr3_photometric_uncertainties package. The calcVar* functions query it
# through u.estimate(band, nobs=...).
u = EDR3_Photometric_Uncertainties.Edr3LogMagUncertainty('./gaia_dr3_photometric_uncertainties/LogErrVsMagSpline.csv')

# Calibration samples, one pair of header-less CSV tables per band (G, BP, RP).
# var90Age unpacks each row of the plain tables into three values (a slope, an
# intercept and a third value it does not use) and each row of the "Dist"
# tables into four (the same three plus a coefficient multiplied by distance).
gCalSamples = pd.read_csv('./calSamples/gCalSamples.csv', header=None)
gCalDistSamples = pd.read_csv('./calSamples/gDistCalSamples.csv', header=None)

bCalSamples = pd.read_csv('./calSamples/bCalSamples.csv', header=None)
bCalDistSamples = pd.read_csv('./calSamples/bDistCalSamples.csv', header=None)

rCalSamples = pd.read_csv('./calSamples/rCalSamples.csv', header=None)
rCalDistSamples = pd.read_csv('./calSamples/rDistCalSamples.csv', header=None)


def calculate(group, rv=-1):
    """Add the per-band variability columns to ``group`` and summarise them.

    The three calcVar* helpers are run in the order G, BP, RP; each one adds
    its columns to ``group`` in place. The return value is whatever
    ``calcVar90(group, rv)`` returns: the 90th percentile of each band's
    variability column followed by its bootstrap standard error, ordered
    G, BP, RP.
    """
    for addBandColumns in (calcVarG, calcVarBP, calcVarRP):
        addBandColumns(group)

    return calcVar90(group, rv)

def age(group, distance=True, band='overall', rv=-1):
    """Return the age estimate for ``group`` in the requested band.

    ``distance`` and ``rv`` are forwarded unchanged to ``var90Age``.

    ``band`` selects what is returned:
      * 'G', 'RP' or 'BP'  -> (age, ageErr) of that single band
      * 'overall'          -> (age, ageErr) of the combination from combineAge
      * anything else      -> the combined pair followed by the G, RP and BP
                              pairs, flattened into one 8-tuple
    """
    gAge, gErr, bpAge, bpErr, rpAge, rpErr = var90Age(group, distance=distance, rv=rv)

    # The combination is always computed, even when only one band is returned.
    combined, combinedErr = combineAge(gAge, gErr, bpAge, bpErr, rpAge, rpErr)

    choices = (
        ('G', (gAge, gErr)),
        ('RP', (rpAge, rpErr)),
        ('BP', (bpAge, bpErr)),
        ('overall', (combined, combinedErr)),
    )
    for key, result in choices:
        if band == key:
            return result

    return combined, combinedErr, gAge, gErr, rpAge, rpErr, bpAge, bpErr
    

############### HELPER FUNCTIONS BELOW ####################

def filters(group):
    """Return a boolean mask selecting the rows of ``group`` that pass every cut.

    Columns read: 'bpsnr', 'rpsnr', 'gsnr', 'plx', 'plx_err', 'gmags', 'dist',
    'bp-rp' and 'Probability'. A row is kept only when all of these hold:

      * BP and RP signal-to-noise above 20, G signal-to-noise above 30
      * parallax over parallax error above 20
      * it is not flagged as a white dwarf (see below)
      * BP-RP colour below 2.5
      * 'Probability' above 0.5

    The cuts are combined with the logical operators ``&`` and ``|`` rather
    than arithmetic ``*`` and ``+`` on boolean Series, so the intent is
    explicit. The selected rows are the same as before. Comparisons involving
    NaN evaluate to False, so rows with missing values are rejected.
    """
    # Photometric signal-to-noise cuts
    snrcut = (group['bpsnr'] > 20) & (group['rpsnr'] > 20) & (group['gsnr'] > 30)

    # Parallax signal-to-noise cut
    plxcut = group['plx'] / group['plx_err'] > 20

    # White dwarf cut: absolute G magnitude from the distance modulus
    # (the formula implies 'dist' is in parsec).
    Mg = group['gmags'] - 5*(np.log10(group['dist']) - 1)
    # A star is kept when it is bright (Mg < 10) OR red (BP-RP > 1);
    # only objects that are both faint and blue are rejected.
    notWhiteDwarf = (Mg < 10) | (group['bp-rp'] > 1)

    # Colour cut
    colorcut = group['bp-rp'] < 2.5

    # Membership probability cut
    probabilityCut = group['Probability'] > .5

    return snrcut & plxcut & notWhiteDwarf & colorcut & probabilityCut



def _gaiaRow(match, probability):
    """Convert one Gaia result row into a dict keyed by the output headers.

    Every value is converted before anything is stored, so a field that cannot
    be converted (or a zero parallax) raises here and the caller drops the
    whole star rather than leaving a half-filled row behind.
    """
    parallax = float(match['parallax'])
    return {
        'Gaia DR3': f'Gaia DR3 {match["source_id"]}',
        'ra': float(match['ra']),
        'dec': float(match['dec']),
        'dist': 1000/parallax,
        'plx': parallax,
        'plx_err': float(match['parallax_error']),
        'pmra': float(match['pmra']),
        'pmdec': float(match['pmdec']),
        # The radial-velocity columns are always written as 0 (unchanged).
        'Vr(obs)': 0,
        'Vr(pred)': 0,
        'ruwe': float(match['ruwe']),
        'gmags': float(match['phot_g_mean_mag']),
        'gfluxs': float(match['phot_g_mean_flux']),
        'gfluxerrs': float(match['phot_g_mean_flux_error']),
        'gnums': int(match['phot_g_n_obs']),
        'bpmags': float(match['phot_bp_mean_mag']),
        'bpfluxs': float(match['phot_bp_mean_flux']),
        'bpfluxerrs': float(match['phot_bp_mean_flux_error']),
        'bpnums': int(match['phot_bp_n_obs']),
        'rpmags': float(match['phot_rp_mean_mag']),
        'rpfluxs': float(match['phot_rp_mean_flux']),
        'rpfluxerrs': float(match['phot_rp_mean_flux_error']),
        'rpnums': int(match['phot_rp_n_obs']),
        'bp-rp': float(match['bp_rp']),
        'bpsnr': float(match['phot_bp_mean_flux_over_error']),
        'rpsnr': float(match['phot_rp_mean_flux_over_error']),
        'gsnr': float(match['phot_g_mean_flux_over_error']),
        'Probability': probability,
    }


def gaiaQuery(file, rewrite=False):
    """Cone-search Gaia DR3 for every row of a CSV file and cache the result.

    Parameters
    ----------
    file : path of the input CSV. It must contain either the columns
        'Gaia DR3', 'Vr(pred)', 'Vr(obs)' or the columns 'ra', 'dec_x',
        'Probability' (the first set is preferred when both are present).
    rewrite : when False and './<name>_gaiaResults.pkl' already exists, that
        pickle is returned and no query is made. '<name>' is the file name up
        to its first dot.

    Returns
    -------
    A DataFrame with one row per matched star, or None when the file cannot be
    read or has neither column layout. The DataFrame is also pickled to
    './<name>_gaiaResults.pkl'.

    For every input row a 5 arcsec cone search is made and the closest source
    is kept, provided it has both BP and RP observations. A star whose query
    or value conversion fails is dropped as a whole, and a one-line summary of
    how many were dropped is printed.
    """
    filename = os.path.split(file)[-1].split(".")[0]
    cachePath = f"./{filename}_gaiaResults.pkl"
    if os.path.exists(cachePath) and not rewrite:
        print("Gaia query results already exist for this file. If you would like to override the previous query, rerun with parameter 'rewrite=True'")
        # NOTE: unpickling runs code stored in the file; only reuse cache
        # files that this function wrote.
        return pd.read_pickle(cachePath)

    # Read the file once, then pick whichever column layout it provides.
    try:
        table = pd.read_csv(file)
    except (OSError, ValueError):
        print(f"{file} not found or is in the wrong format")
        return None

    layouts = (['Gaia DR3', 'Vr(pred)', 'Vr(obs)'], ['ra', 'dec_x', 'Probability'])
    for columns in layouts:
        if set(columns).issubset(table.columns):
            FFinfo = table[columns]
            break
    else:
        print(f"{file} not found or is in the wrong format")
        return None

    # NOTE(review): with the first layout the search string below is built
    # from the 'Gaia DR3' and 'Vr(pred)' values and 'Probability' is filled
    # from 'Vr(obs)', exactly as before. Confirm that layout is still in use.

    headers = ['Gaia DR3', 'ra', 'dec', 'dist', 'plx', 'plx_err', 'pmra', 'pmdec', 'Vr(obs)', 'Vr(pred)', 'ruwe',  'gmags', 'gfluxs', 'gfluxerrs', 'gnums', 'bpmags', 'bpfluxs', 'bpfluxerrs', 'bpnums', 'rpmags', 'rpfluxs', 'rpfluxerrs', 'rpnums', 'bp-rp', 'bpsnr', 'rpsnr', 'gsnr', 'Probability']

    rows = []
    skipped = 0
    lastError = None

    for star in FFinfo.values:
        try:
            job = Gaia.cone_search(f'{star[0]} {star[1]}', radius = 5*un.arcsec, table_name = "gaiadr3.gaia_source")
            r = job.get_results()

            # Keep the source closest to the search position. An empty result
            # raises here and the star is skipped.
            match = r[int(np.argmin(r['dist']))]

            # Sources without BP or RP photometry are left out silently.
            if int(match['phot_bp_n_obs']) == 0 or int(match['phot_rp_n_obs']) == 0:
                continue

            rows.append(_gaiaRow(match, star[2]))

        except Exception as exc:
            # Best effort: one failing star must not abort the whole run.
            # Catching Exception rather than everything keeps Ctrl-C working
            # during a long series of queries.
            skipped += 1
            lastError = exc

    if skipped:
        print(f"{skipped} of {len(FFinfo)} targets were skipped because the query or a value conversion failed (last error: {lastError!r})")

    df = pd.DataFrame(rows, columns=headers)

    with open(f'./{filename}_gaiaResults.pkl', 'wb') as handle:
        pickle.dump(df, handle)

    print(f"Gaia query results saved to './{filename}_gaiaResults.pkl'")

    return df



def _setColumn(group, name, values):
    """Write ``values`` into column ``name``: overwrite it if present, else append it."""
    if name in group.columns:
        group[name] = values
    else:
        group.insert(len(group.columns), name, values)


def _calcVarBand(group, band):
    """Add the uncertainty-model and variability columns for one band, in place.

    ``band`` is 'g', 'bp' or 'rp'. It selects the input columns
    '<band>nums', '<band>mags', '<band>fluxerrs' and '<band>fluxs', the band
    passed to ``u.estimate`` and the two output columns:

      * 'log10(sigma_<band>_nobs)': the 'logU_<n>' curve returned by
        ``u.estimate`` for the star's number of observations n, interpolated
        (and extrapolated where needed) at the star's magnitude.
      * 'var<BAND>': log10((2.5/ln 10) * fluxerr / flux) minus the column above.

    Rows are processed by position, so the frame may carry any index.
    """
    sigmaCol = f"log10(sigma_{band}_nobs)"
    varCol = f"var{band.upper()}"

    nums = group[f'{band}nums'].to_numpy()
    mags = group[f'{band}mags'].to_numpy()

    if nums.size == 0:
        # Nothing to estimate; still create the columns so callers find them.
        _setColumn(group, sigmaCol, np.zeros(0))
        _setColumn(group, varCol, np.zeros(0))
        return

    nobs = np.arange(np.min(nums)-1, np.max(nums)+1, 1)
    gn = u.estimate(band, nobs=nobs)
    mag = gn[f'mag_{band}']

    # One interpolator per distinct number of observations, not one per star.
    interpolators = {}
    sig_est = np.zeros(nums.size)
    for k, (n, m) in enumerate(zip(nums, mags)):
        n = int(n)
        if n not in interpolators:
            interpolators[n] = interpolate.interp1d(mag, gn[f'logU_{n:d}'], fill_value="extrapolate")
        sig_est[k] = interpolators[n](m)

    _setColumn(group, sigmaCol, sig_est)

    varindx = np.log10(((2.5/np.log(10))*group[f'{band}fluxerrs'] / group[f'{band}fluxs'])) - (group[sigmaCol])
    _setColumn(group, varCol, varindx)


def calcVarG(group):
    """Add 'log10(sigma_g_nobs)' and 'varG' to ``group`` in place (G band)."""
    _calcVarBand(group, 'g')
    return



def calcVarRP(group):
    """Add 'log10(sigma_rp_nobs)' and 'varRP' to ``group`` in place (RP band)."""
    _calcVarBand(group, 'rp')
    return


def calcVarBP(group):
    """Add 'log10(sigma_bp_nobs)' and 'varBP' to ``group`` in place (BP band)."""
    _calcVarBand(group, 'bp')
    return



def ninety(arr,axis=None):
    """Return the 90th percentile of ``arr`` along ``axis``, ignoring NaNs.

    Accepts an ``axis`` argument so that it can be handed to
    ``scipy.stats.bootstrap`` as the statistic (see calcVar90).
    """
    percentileRank = 90
    return np.nanpercentile(a=arr, q=percentileRank, axis=axis)


def calcVar90(group, rv = -1):
    """Return the 90th percentile of each variability column and its error.

    Rows are selected with ``filters(group)`` combined with a radial-velocity
    cut: when ``rv`` is -1 the cut is ``abs(Vr(pred)) >= 0``, otherwise it is
    ``abs(Vr(obs) - Vr(pred)) < rv``.

    For 'varG', 'varBP' and 'varRP' (in that order) the NaN-ignoring 90th
    percentile of the selected values is followed by the standard error that
    ``scipy.stats.bootstrap`` reports for the ``ninety`` statistic, giving a
    6-tuple (G, Gerr, BP, BPerr, RP, RPerr).
    """
    if rv == -1:
        rvCut = abs(group['Vr(pred)']) >= 0
    else:
        rvCut = abs(group['Vr(obs)'] - group['Vr(pred)'] ) < rv

    # The selection is the same for all three bands, so build it once.
    keep = filters(group) * rvCut

    summary = []
    for column in ("varG", "varBP", "varRP"):
        selected = group[column][keep]
        summary.append(np.nanpercentile(selected, 90))
        summary.append(bootstrap((selected,), ninety).standard_error)

    return tuple(summary)



def _sampleBandAge(perc90, perc90Err, calSamples, scatter, distVal=None, iterations=10000):
    """Monte-Carlo the age of one band; return (median age, [lower, upper] errors).

    Each trial picks a random calibration row, draws the 90th-percentile value
    from a normal distribution (``perc90``, ``perc90Err``) and evaluates
    ``draw * slope + intercept`` in log10(age), plus ``distCoeff * distVal``
    when ``distVal`` is given. The trials are then smeared with a normal
    scatter whose width is itself drawn from ``scatter = (mean, std)``.

    Random numbers are drawn in the same order as before, so a seeded run is
    reproducible.
    """
    # A plain array avoids a pandas .iloc lookup in every trial.
    calRows = calSamples.to_numpy()
    lastRow = len(calRows) - 1

    trials = []
    for _ in range(iterations):
        row = calRows[random.randint(0, lastRow)]
        perSample = np.random.normal(perc90, perc90Err)

        if distVal is None:
            mSample, bSample, _unused = row
            trials.append(perSample * mSample + bSample)
        else:
            mSample, bSample, _unused, dSample = row
            trials.append(perSample * mSample + dSample * distVal + bSample)

    scatterMean, scatterStd = scatter
    # abs() guards against the rare negative draw, which is not a valid scale.
    scatterWidth = abs(np.random.normal(scatterMean, scatterStd, np.size(trials)))
    trials = np.random.normal(trials, scatterWidth)

    # 16th / 50th / 84th percentiles in log10(age), converted to age.
    ages = 10**np.percentile(trials, [16, 50, 84])
    return ages[1], np.diff(ages)


def var90Age(group, distance=True, rv=-1):
    """Estimate an age from each band's 90th-percentile variability.

    Parameters
    ----------
    group : DataFrame that already holds the 'varG', 'varBP' and 'varRP'
        columns (read through ``calcVar90``).
    distance : selects the calibration.
        * False  -> calibration without a distance term.
        * True or None -> distance calibration, using the NaN-ignoring median
          of group['dist'] over the rows that pass ``filters``.
        * a number (Python or numpy) -> distance calibration with that value;
          0 falls back to the median, as before.
        * any other object is treated by its truthiness.
    rv : forwarded to ``calcVar90``.

    Returns
    -------
    (varGage, varGageErr, varBPage, varBPageErr, varRPage, varRPageErr), where
    each age is the median of the Monte-Carlo trials and each error is a
    two-element array with the distances from the median down to the 16th and
    up to the 84th percentile.

    The distance that is used is printed when a distance calibration is run.
    """
    iterations = 10000
    distVal = 0

    Gperc90, Gper90Err, Bperc90, Bper90Err, Rperc90, Rper90Err = calcVar90(group, rv=rv)

    if isinstance(distance, (bool, np.bool_)):
        useDistance = bool(distance)
    elif distance is None:
        useDistance = True
    elif isinstance(distance, (int, float, np.integer, np.floating)):
        # Includes numpy scalars, which an exact type() comparison ignored.
        distVal = distance
        useDistance = True
    else:
        useDistance = bool(distance)

    # (calibration table, (mean, std) of the scatter) for G, BP, RP.
    if useDistance:
        if distVal == 0:
            distVal = np.nanmedian(group['dist'][filters(group)])

        print(distVal)

        sampleDist = distVal
        calibrations = (
            (gCalDistSamples, (0.155, 0.025)),
            (bCalDistSamples, (0.129, 0.024)),
            (rCalDistSamples, (0.089, 0.023)),
        )
    else:
        sampleDist = None
        calibrations = (
            (gCalSamples, (0.178, 0.024)),
            (bCalSamples, (0.177, 0.025)),
            (rCalSamples, (0.141, 0.025)),
        )

    percentiles = ((Gperc90, Gper90Err), (Bperc90, Bper90Err), (Rperc90, Rper90Err))

    results = []
    for (perc90, perc90Err), (calSamples, scatter) in zip(percentiles, calibrations):
        results.extend(_sampleBandAge(perc90, perc90Err, calSamples, scatter,
                                      distVal=sampleDist, iterations=iterations))

    return tuple(results)



def combineAge(varGage, varGageErr, varBPage, varBPageErr, varRPage, varRPageErr):
    """Combine the G, BP and RP ages into one inverse-variance weighted age.

    Each ``*Err`` argument is indexable with [0] (lower error) and [1] (upper
    error). Two weighted averages are formed, one with weights 1/lower**2 and
    one with weights 1/upper**2, and the returned age is the plain mean of the
    two. The returned error is the list [lowerErr, upperErr], each computed as
    sqrt(1 / sum(weights)) for its side.
    """
    bandAges = np.asarray([varGage, varBPage, varRPage])
    bandErrs = (varGageErr, varBPageErr, varRPageErr)

    sideAverages = []
    sideErrors = []
    for side in (0, 1):  # 0 = lower age error, 1 = upper age error
        weights = np.asarray([1/(err[side])**2 for err in bandErrs])
        sideAverages.append(np.average(bandAges, weights=weights))
        sideErrors.append(np.sqrt(1/sum(weights)))

    return np.average(sideAverages), sideErrors
   


In [50]:
# Build the working table for the input list. With rewrite=False, gaiaQuery
# returns the cached './tic_gaia_banyan_gaiaResults.pkl' when it exists
# instead of querying Gaia again (see the message printed below).
filepath = "./tic_gaia_banyan.csv"
group = gaiaQuery(filepath, rewrite=False)

Gaia query results already exist for this file. If you would like to override the previous query, rerun with parameter 'rewrite=True'


In [51]:
# Adds the per-band variability columns to `group` in place and returns the
# 90th percentile of each band with its bootstrap standard error.
varG90, varG90Err, varBP90, varBP90Err, varRP90, varRP90Err = calculate(group)

In [102]:
# 'all' matches none of the single-band options, so age() returns everything:
# (combined age, combined err, G age, G err, RP age, RP err, BP age, BP err).
# The first line printed below is the distance that var90Age used.
age(group, band = 'all', distance = True)

54.73585946043388


(66.61913204931595,
 [11.799026950556028, 16.006755302677806],
 109.43271902374833,
 array([35.46102801, 52.26893435]),
 70.62692792577519,
 array([18.70501152, 23.79720207]),
 53.75052667355897,
 array([16.83194754, 23.76175369]))

In [100]:
# Cross-check of the distance printed by var90Age, which takes the
# NaN-ignoring median of the same filtered 'dist' values.
np.median(group['dist'][filters(group)])

54.73585946043388

In [97]:
# Show the rows that pass every cut in filters().
# NOTE(review): rows 15 and 16 of the output below are the same Gaia DR3
# source with different Probability values; duplicate input targets are not
# removed anywhere in this notebook. Confirm whether that is intended.
group[filters(group)]

Unnamed: 0,Gaia DR3,ra,dec,dist,plx,plx_err,pmra,pmdec,Vr(obs),Vr(pred),...,bpsnr,rpsnr,gsnr,Probability,log10(sigma_g_nobs),varG,log10(sigma_bp_nobs),varBP,log10(sigma_rp_nobs),varRP
1,Gaia DR3 4899996487129314688,0.931330,-65.779834,33.954993,29.450749,0.015428,137.940879,-59.435737,0,0,...,829.804626,1520.055420,2612.727783,0.781425,-3.806932,0.425562,-3.244133,0.360881,-3.528009,0.381874
7,Gaia DR3 2860458423979097984,4.294044,29.530688,54.759045,18.261823,0.027096,76.248929,-90.545502,0,0,...,396.244202,141.359207,1829.668823,0.586712,-3.662876,0.436228,-3.324192,0.761953,-3.544974,1.430374
15,Gaia DR3 2323761887650523520,6.133932,-25.381822,42.993302,23.259437,0.023742,106.873573,-100.390318,0,0,...,299.334747,810.964966,1465.331421,0.933957,-3.686640,0.556428,-3.088305,0.647872,-3.403623,0.530345
16,Gaia DR3 2323761887650523520,6.133932,-25.381822,42.993302,23.259437,0.023742,106.873573,-100.390318,0,0,...,299.334747,810.964966,1465.331421,0.990005,-3.686640,0.556428,-3.088305,0.647872,-3.403623,0.530345
18,Gaia DR3 2855866348024648320,6.861575,26.278478,61.312400,16.309915,0.024643,70.914920,-71.492128,0,0,...,3049.841064,1857.980469,5365.693848,0.989016,-3.750950,0.057048,-3.201149,-0.247404,-3.288748,0.055431
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
979,Gaia DR3 1920997794711376640,357.463339,39.938256,41.689146,23.987059,0.024531,98.249784,-98.146203,0,0,...,470.670837,767.427551,1452.870239,0.982563,-3.812932,0.686429,-3.326153,0.689160,-3.643866,0.794553
980,Gaia DR3 1926214668148933120,357.603057,44.926109,59.560622,16.789617,0.018948,68.429902,-59.155818,0,0,...,295.264862,632.155884,1266.771240,0.952778,-3.758215,0.691241,-3.111817,0.677330,-3.486896,0.721796
985,Gaia DR3 6488114422112039168,358.668394,-60.859715,42.464002,23.549358,0.011450,109.474270,-61.742751,0,0,...,1324.040649,2419.980957,5275.288574,0.934413,-3.847066,0.160544,-3.433662,0.347485,-3.619154,0.271066
986,Gaia DR3 2847644410526748416,358.861797,22.192629,25.474116,39.255533,0.022137,202.046341,-147.050252,0,0,...,1372.395996,1022.114929,4973.564941,0.875483,-3.808918,0.147975,-3.463351,0.361596,-3.550501,0.576726
