In [1]:
import numpy as np
from astropy.table import Table
from astropy import table
from astroquery.vizier import Vizier
from astroquery.gaia import Gaia
from astroquery.mast import Catalogs
from astropy.coordinates import SkyCoord
from astropy.io.votable import parse
from astropy.constants import G, M_sun, pc, R_sun, h, c, k_B
# from dust_extinction.parameter_averages import G23
import extinction
import astropy.units as u
from scipy.optimize import curve_fit, bisect
from scipy.interpolate import LinearNDInterpolator, interp1d
import os.path
from os import listdir
import matplotlib.pyplot as plt
from csv import writer
from tqdm import tqdm
import warnings

## suppress every warning for the rest of the session (note: this also hides
## warnings raised by the catalog queries and fits below)
warnings.filterwarnings('ignore')

## figure widths: 89 mm and 183 mm converted to inches (25.4 mm per inch)
figwidth_single = 89/25.4  # [inch]
figwidth_double = 183/25.4  # [inch]
## IPython magic: draw figures in interactive Tk windows instead of inline
%matplotlib tk

# Functions
* Yarin: run all three cells in this section.

In [2]:
## filter zero points, effective wavelengths, column names
##
## Every *_zp_table has one row per photometric band:
##   band          - filter id; bands_table rows are keyed by it, and blackbody_mod_table
##                   reads the filter curve ../data/VOSA/filters/<band>.dat
##   col, err_col  - catalog column holding the magnitude (a flux for the Gaia synthetic
##                   tables) and its error; err_col exists only for catalog-matched surveys
##   lambda_eff    - effective wavelength [AA]
##   width         - filter width (presumably [AA] like lambda_eff -- TODO confirm)
##   f0, m0        - zero point: f = f0 * 10^(-0.4*(m-m0)), f in erg/s/cm^2/AA
##   wd_f0, wd_lambda_eff - zero point / effective wavelength for the Bedard models;
##                   equal to f0 / lambda_eff where Bedard gives none (see inline notes)

## f = f0 * 10^(-0.4*(m-m0)) (erg/s/cm^2/AA) 

## wise_f0 = (np.array([309.540,171.787,31.674]) * u.Jy.to(u.erg/u.s/u.cm**2/u.Hz) * u.erg/u.s/u.cm**2/u.Hz * c.cgs / (np.array([33526,46030,115608]) * u.AA)**2).to(u.erg/u.s/u.cm**2/u.AA)

## gaia_f0 = np.array([3.009167e-21,1.346109e-21,1.638483e-21]) * (u.W/u.m**2/u.nm).to(u.erg/u.s/u.cm**2/u.AA)

# two_mass_f0 = np.array([3.129e-13,1.133e-13,4.283e-14]) * (u.W / u.cm**2 / u.micron).to(u.erg/u.s/u.cm**2/u.AA)


## 2MASS J, H, Ks (magnitudes from VizieR II/246)
twomass_zp_table = Table({'band': ['2MASS.J','2MASS.H', '2MASS.Ks'],
                        'col':['Jmag','Hmag','Kmag'],
                        'err_col':['e_Jmag','e_Hmag','e_Kmag'],
                        'lambda_eff':[12350,16620,21590],
                        'width':[1620,2510,2640],
                        'f0':[3.129e-10,1.133e-10,4.283e-11],
                        'm0':[0,0,0],
                        'wd_f0':[3.106e-10,1.143e-10,4.206e-11], ## bedard f0 different than VOSA
                        'wd_lambda_eff':[12350,16620,21590]}) 

## WISE W1, W2, W3 (magnitudes from VizieR II/311)
wise_zp_table = Table({'band': ['WISE.W1','WISE.W2','WISE.W3'],
                       'col':['W1mag','W2mag','W3mag'],
                          'err_col':['e_W1mag','e_W2mag','e_W3mag'],
                          'lambda_eff':[33526,46028,115608],
                          'width':[6620,10422,55055],
                                'f0':[8.18e-12,2.42e-12,6.52e-14],
                                'm0':[0,0,0],
                                'wd_f0':[8.18e-12,2.42e-12,6.52e-14],
                                'wd_lambda_eff':[33526,46028,115608]}) ## bedard doesn't give WISE f0

## Gaia DR3 BP, G, RP (gaiadr3.gaia_source); err_col holds the archive flux error,
## which get_gaia_photometry turns into a fractional error
gaia_zp_table = Table({'band': ['GAIA3.Gbp','GAIA3.G','GAIA3.Grp'],
                       'col':['phot_bp_mean_mag','phot_g_mean_mag','phot_rp_mean_mag'],
                       'err_col':['phot_bp_mean_flux_error','phot_g_mean_flux_error','phot_rp_mean_flux_error'],
                       'lambda_eff':[5035,5822,7619],
                       'width':[2333,4203,2842],
                       'f0':[4.08e-9,2.5e-9,1.27e-9],
                       'm0':[0,0,0],
                       'wd_f0':[4.08e-9,2.5e-9,1.27e-9], ## bedard doesn't give GAIA f0
                       'wd_lambda_eff':[5035,5822,7619]})

## GALEX FUV, NUV (column names of the MAST 'Galex' catalog)
galex_zp_table = Table({'band': ['GALEX.FUV', 'GALEX.NUV'],
                       'col':['fuv_mag','nuv_mag'],
                       'err_col':['fuv_magerr','nuv_magerr'],
                       'lambda_eff':[1548,2303],
                           'width':[265,770],
                       'f0':[4.6e-8,2.05e-8],
                       'm0':[0,0],
                       'wd_f0':[4.6e-8,2.05e-8], ## bedard doesn't give GALEX f0- its AB anyway
                       'wd_lambda_eff':[1548,2303]})

## Gaia DR3 synthetic photometry (Johnson UBVRI + SDSS ugriz); 'col' is a flux column.
## NOTE(review): this combined table is replaced by `synt_zp_table = jkc_zp_table`
## below; the split sdss_zp_table / jkc_zp_table are the ones actually used.
synt_zp_table = Table({'band': ['Johnson.U', 'Johnson.B','Johnson.V','Johnson.R','Johnson.I',
                             'SDSS.u','SDSS.g','SDSS.r','SDSS.i','SDSS.z'],
                     'lambda_eff':[3551.05,4369.53,5467.57,6695.83,8568.89,3608.04,4671.78,6141.12,7457.89,8922.78],
                     'width':[657,972,889,2070,2316,541,1064,1055,1102,1164],
                    'col':['u_jkc_flux','b_jkc_flux','v_jkc_flux','r_jkc_flux','i_jkc_flux',
                           'u_sdss_flux','g_sdss_flux','r_sdss_flux','i_sdss_flux','z_sdss_flux'],
                  'f0':[3.49719e-9,6.72553e-9,3.5833e-9,1.87529e-9,9.23651e-10,
                        3.75079e-9,5.45476e-9,2.49767e-9,1.38589e-9,8.38585e-10],
                        'm0':[0,0,0,0,0,0,0,0,0,0],
                        'wd_f0':[3.684e-9,6.548e-9,3.804e-9,2.274e-9,1.119e-9,
                                 1.1436e-8,4.9894e-9,2.8638e-9,1.9216e-9,1.3343e-9],
                        'wd_lambda_eff':[3971,4491,5423,6441,8071,
                                         3146,4670,6156,7471,8918]}) ## bedard f0 different than VOSA
## SDSS ugriz subset of the synthetic photometry
sdss_zp_table = Table({'band': ['SDSS.u','SDSS.g','SDSS.r','SDSS.i','SDSS.z'],
                      'lambda_eff':[3608.04,4671.78,6141.12,7457.89,8922.78],
                      'width':[541,1064,1055,1102,1164],
                      'col':['u_sdss_flux','g_sdss_flux','r_sdss_flux','i_sdss_flux','z_sdss_flux'],
                      'f0':[3.75079e-9,5.45476e-9,2.49767e-9,1.38589e-9,8.38585e-10],
                      'm0':[0,0,0,0,0],
                      'wd_f0':[1.1436e-8,4.9894e-9,2.8638e-9,1.9216e-9,1.3343e-9],
                      'wd_lambda_eff':[3146,4670,6156,7471,8918]})
## Johnson UBVRI subset of the synthetic photometry
jkc_zp_table = Table({'band': ['Johnson.U', 'Johnson.B','Johnson.V','Johnson.R','Johnson.I'],
                      'lambda_eff':[3551.05,4369.53,5467.57,6695.83,8568.89],
                      'width':[657,972,889,2070,2316],
                      'col':['u_jkc_flux','b_jkc_flux','v_jkc_flux','r_jkc_flux','i_jkc_flux'],
                    'f0':[3.49719e-9,6.72553e-9,3.5833e-9,1.87529e-9,9.23651e-10],
                          'm0':[0,0,0,0,0],
                          'wd_f0':[3.684e-9,6.548e-9,3.804e-9,2.274e-9,1.119e-9],
                          'wd_lambda_eff':[3971,4491,5423,6441,8071]})

## choose which Gaia synthetic-photometry system enters the fit; the wd_band labels below
## are looked up by band name, so switching systems needs no other edit
synt_zp_table = jkc_zp_table
# synt_zp_table = sdss_zp_table

## column label of every band in the model SED grids (used to index e.g. kurucz_sed.dat)
wd_band_dict = {'GALEX.FUV':'FUV', 'GALEX.NUV':'NUV',
                'Johnson.U':'U', 'Johnson.B':'B', 'Johnson.V':'V', 'Johnson.R':'R', 'Johnson.I':'I',
                'SDSS.u':'u', 'SDSS.g':'g', 'SDSS.r':'r', 'SDSS.i':'i', 'SDSS.z':'z',
                'GAIA3.Gbp':'G3_BP', 'GAIA3.G':'G3', 'GAIA3.Grp':'G3_RP',
                '2MASS.J':'J', '2MASS.H':'H', '2MASS.Ks':'Ks',
                'WISE.W1':'W1', 'WISE.W2':'W2', 'WISE.W3':'W3'}

## stack the per-survey zero-point tables into one band table, sorted by wavelength
_zp_tables = [gaia_zp_table, wise_zp_table, twomass_zp_table, synt_zp_table, galex_zp_table]
_band_cols = ['band', 'lambda_eff', 'f0', 'm0', 'wd_f0', 'wd_lambda_eff', 'width']
bands_table = Table({col: [val for zp_tbl in _zp_tables for val in zp_tbl[col]] for col in _band_cols})
bands_table.sort('lambda_eff')
## label each row by name rather than by position, so the labels stay correct for any sort order
bands_table.add_column([wd_band_dict[b] for b in bands_table['band']], name='wd_band')

In [3]:
## get photometry

def get_synthetic_photometry(source_table):
    """Query Gaia DR3 synthetic photometry (gaiadr3.synthetic_photometry_gspc).

    Fetches the flux, flux error and flag of every band in `synt_zp_table`. Fluxes and
    errors are converted to erg/s/cm^2/AA. Values whose flag is 0 are replaced by NaN.
    Columns are renamed to '<band>' / '<band>_err'.

    Parameters
    ----------
    source_table : Table with a 'source_id' column of Gaia DR3 source ids.

    Returns
    -------
    Table with 'source_id' and one flux / error column pair per band. If the archive
    returns no rows, a single all-NaN row for the first source id is added.
    """
    ## build the ADQL id list by hand: str() of a one-element tuple is '(123,)', and the
    ## trailing comma is invalid ADQL, so a single-source query would fail
    id_lst = '(' + ', '.join(str(int(sid)) for sid in source_table['source_id']) + ')'
    cols_to_query =','.join(['source_id'] + [c + ',' + c +'_error' + ',' + c.replace('flux','flag') for c in synt_zp_table['col']])
    query = 'SELECT ' + cols_to_query + f' FROM gaiadr3.synthetic_photometry_gspc WHERE source_id IN {id_lst}'
    job = Gaia.launch_job(query)
    result = job.get_results()
    for col in synt_zp_table['col']:
        band = synt_zp_table[synt_zp_table['col']==col]['band'][0]
        err_col = col + '_error'
        flag_col = col.replace('flux','flag')
        if band.startswith('Johnson'): ## Johnson fluxes are returned in W/m^2/nm (unit attached to the column)
            result[col] = result[col].to(u.erg/u.s/u.cm**2/u.AA)
            result[err_col] = result[err_col].to(u.erg/u.s/u.cm**2/u.AA)
        elif band.startswith('SDSS'): ## SDSS fluxes are f_nu in W/m^2/Hz: f_lambda = f_nu * c / lambda_eff^2
            wl = synt_zp_table[synt_zp_table['col']==col]['lambda_eff'][0] * u.AA
            result[col] = (result[col].data * u.W / u.m**2 / u.Hz * c / wl.si**2).to(u.erg/u.s/u.cm**2/u.AA)
            result[err_col] = (result[err_col].data * u.W / u.m**2 / u.Hz * c / wl.si**2).to(u.erg/u.s/u.cm**2/u.AA)

        flagged = result[flag_col].astype(bool) ## flag = 1 OK, flag = 0 bad
        result[col][~flagged] = np.nan
        result[err_col][~flagged] = np.nan

        result[col].name = band
        result[err_col].name = band +'_err'

    if 'SOURCE_ID' in result.colnames:
        result.rename_column('SOURCE_ID','source_id')
    result.keep_columns(['source_id'] + list(synt_zp_table['band']) + list(band + '_err' for band in synt_zp_table['band']))
    if len(result) == 0:
        sid = source_table['source_id'][0]
        result.add_row([sid if col == 'source_id' else np.nan for col in result.colnames])
    return result

def get_gaia_photometry(source_table):
    """Query Gaia DR3 G, BP and RP photometry (gaiadr3.gaia_source) for every source.

    Magnitudes are converted to erg/s/cm^2/AA with f = f0 * 10^(-0.4*(m-m0)), using the
    zero points in `gaia_zp_table`. Each flux error is the archive's fractional flux error
    (flux_error / flux) times the converted flux. Columns are renamed to '<band>' / '<band>_err'.

    Parameters
    ----------
    source_table : Table with a 'source_id' column of Gaia DR3 source ids.

    Returns
    -------
    Table with 'source_id' and one flux / error column pair per Gaia band.
    """
    ## build the ADQL id list by hand: str() of a one-element tuple is '(123,)', and the
    ## trailing comma is invalid ADQL, so a single-source query would fail
    id_lst = '(' + ', '.join(str(int(sid)) for sid in source_table['source_id']) + ')'
    cols_to_query =','.join(['source_id'] + [c for c in gaia_zp_table['col']] + [c.replace('mag','flux') + ',' + c.replace('mag','flux') +'_error' for c in gaia_zp_table['col']])
    query = 'SELECT ' + cols_to_query + f' FROM gaiadr3.gaia_source WHERE source_id IN {id_lst}'
    job = Gaia.launch_job(query)
    result = job.get_results()
    for col in gaia_zp_table['col']:
        band = gaia_zp_table[gaia_zp_table['col']==col]['band'][0]
        err_col = gaia_zp_table[gaia_zp_table['col']==col]['err_col'][0]
        f0 = gaia_zp_table[gaia_zp_table['col']==col]['f0'][0]
        m0 = gaia_zp_table[gaia_zp_table['col']==col]['m0'][0]
        ## fractional error: the archive fluxes are in arbitrary units, but the fractional
        ## error is the same in erg/s/cm^2/AA
        err_over_f = result[col.replace('mag','flux')+'_error'] / result[col.replace('mag','flux')]
        mag = result[col]
        result[col] = f0 * 10**(-0.4*(mag-m0))
        result[col].unit = u.erg / u.s / u.cm**2 / u.AA
        result[col].name = band
        result[err_col] = result[band] * err_over_f ## fractional error -> absolute error in physical units
        result[err_col].unit = u.erg / u.s / u.cm**2 / u.AA
        result[err_col].name = band + '_err'
    if 'SOURCE_ID' in result.colnames:
        result.rename_column('SOURCE_ID','source_id')
    result.keep_columns(['source_id'] + list(gaia_zp_table['band']) + [band + '_err' for band in gaia_zp_table['band']])
    return result

def get_wise_photometry(source_table):
    """Cross-match sources with AllWISE (VizieR II/311, 2 arcsec) and return W1-W3 fluxes.

    Magnitudes are converted to erg/s/cm^2/AA with the zero points in `wise_zp_table`.
    Columns are renamed to '<band>' / '<band>_err'. Points with SNR < 3 or a masked error
    are set to NaN. Sources without a match get an all-NaN row, and the result is sorted
    by 'idx'.

    Parameters
    ----------
    source_table : Table with 'ra', 'dec' (deg) and 'idx' columns.
    """
    positions = SkyCoord(source_table['ra'], source_table['dec'], unit=(u.deg, u.deg), frame='icrs')
    matches = Vizier.query_region(positions, radius=2*u.arcsec,catalog=['II/311'])
    if len(matches) > 0:
        wise = matches[0]
        ## '_q' is the 1-based index of the query position; translate it into the source 'idx'
        wise['_q'] = wise['_q'] - 1
        wise['_q'] = source_table['idx'][wise['_q']]
        wise['_q'].name = 'idx'
    else:
        ## no match at all: all-NaN table carrying the catalog column names
        wise = Table({'idx': source_table['idx'],
                      **{name: np.full_like(source_table['idx'], np.nan, dtype=float)
                         for name in ['W1mag', 'W2mag', 'W3mag', 'e_W1mag', 'e_W2mag', 'e_W3mag']}})
    wise.keep_columns(['idx'] + list(wise_zp_table['col']) + list(wise_zp_table['err_col']))

    ## magnitudes -> fluxes; magnitude errors -> flux errors via sigma_f = f * 0.4 ln(10) * sigma_m
    for zp in wise_zp_table:
        mag_col, mag_err_col, band = zp['col'], zp['err_col'], zp['band']
        wise[mag_col] = zp['f0'] * 10**(-0.4*(wise[mag_col] - zp['m0']))
        wise[mag_col].unit = u.erg / u.s / u.cm**2 / u.AA
        wise[mag_col].name = band
        wise[mag_err_col] = wise[band] * 0.4 * np.log(10) * wise[mag_err_col]
        wise[mag_err_col].unit = u.erg / u.s / u.cm**2 / u.AA
        wise[mag_err_col].name = band + '_err'

    ## keep only points with SNR >= 3 and a defined error (what WISE would flag A or B)
    for band in wise_zp_table['band']:
        err_name = band + '_err'
        for i in range(len(wise)):
            row = wise[i]
            if (row[band] / row[err_name] < 3) or np.ma.is_masked(row[err_name]):
                row[band] = np.nan
                row[err_name] = np.nan

    ## one all-NaN row for every source without an AllWISE match
    n_vals = len(wise_zp_table['col']) + len(wise_zp_table['err_col'])
    for idx in source_table['idx']:
        if idx not in wise['idx']:
            wise.add_row([idx] + [np.nan] * n_vals)
    wise.sort('idx')
    return wise

def get_twomass_photometry(source_table):
    """Cross-match sources with 2MASS (VizieR II/246, 2 arcsec) and return J, H, Ks fluxes.

    Magnitudes are converted to erg/s/cm^2/AA with the zero points in `twomass_zp_table`.
    Columns are renamed to '<band>' / '<band>_err'. A band is set to NaN unless its
    character in the 'Qflg' quality flag is 'A'. Sources without a match get an all-NaN
    row, and the result is sorted by 'idx'.

    Parameters
    ----------
    source_table : Table with 'ra', 'dec' (deg) and 'idx' columns.
    """
    positions = SkyCoord(source_table['ra'], source_table['dec'], unit=(u.deg, u.deg), frame='icrs')
    matches = Vizier.query_region(positions, radius=2*u.arcsec,catalog=['II/246'])
    if len(matches) > 0:
        twomass = matches[0]
        ## '_q' is the 1-based index of the query position; translate it into the source 'idx'
        twomass['_q'] = twomass['_q'] - 1
        twomass['_q'] = source_table['idx'][twomass['_q']]
        twomass['_q'].name = 'idx'
    else:
        ## no match at all: all-NaN table whose quality flags fail every band
        twomass = Table({'idx': source_table['idx'],
                         **{name: np.full_like(source_table['idx'], np.nan, dtype=float)
                            for name in ['Jmag', 'Hmag', 'Kmag', 'e_Jmag', 'e_Hmag', 'e_Kmag']},
                         'Qflg': np.full_like(source_table['idx'], 'ZZZ', dtype='<U3')})
    ## keep a handle on the per-band (J, H, Ks) quality flags before keep_columns drops 'Qflg'
    qflags = twomass['Qflg']
    twomass.keep_columns(['idx'] + list(twomass_zp_table['col']) + list(twomass_zp_table['err_col']))

    ## magnitudes -> fluxes; magnitude errors -> flux errors via sigma_f = f * 0.4 ln(10) * sigma_m
    for zp in twomass_zp_table:
        mag_col, mag_err_col, band = zp['col'], zp['err_col'], zp['band']
        twomass[mag_col] = zp['f0'] * 10**(-0.4*(twomass[mag_col] - zp['m0']))
        twomass[mag_col].unit = u.erg / u.s / u.cm**2 / u.AA
        twomass[mag_col].name = band
        twomass[mag_err_col] = twomass[band] * 0.4 * np.log(10) * twomass[mag_err_col]
        twomass[mag_err_col].unit = u.erg / u.s / u.cm**2 / u.AA
        twomass[mag_err_col].name = band + '_err'

    ## character j of the flag belongs to band j; anything but 'A' is discarded
    for i in range(len(twomass)):
        for j, band in enumerate(twomass_zp_table['band']):
            if qflags[i][j] != 'A':
                twomass[i][band] = np.nan
                twomass[i][band + '_err'] = np.nan

    ## one all-NaN row for every source without a 2MASS match
    n_vals = len(twomass_zp_table['col']) + len(twomass_zp_table['err_col'])
    for idx in source_table['idx']:
        if idx not in twomass['idx']:
            twomass.add_row([idx] + [np.nan] * n_vals)
    twomass.sort('idx')
    return twomass

def get_galex_photometry(source_table):
    """Return GALEX FUV/NUV fluxes (erg/s/cm^2/AA) for every source in `source_table`.

    Each band uses GALEX DR7 (VizieR II/335/galex_ais, 2 arcsec) when that catalog has a
    valid value for the source. Otherwise it falls back to the MAST 'Galex' catalog (first
    row returned within 2 arcsec). Bands with neither stay NaN.

    Parameters
    ----------
    source_table : Table with 'ra', 'dec' (deg) and 'idx' columns.

    Returns
    -------
    Table with 'idx' and '<band>' / '<band>_err' columns per GALEX band, one row per source.
    """
    coords = SkyCoord(source_table['ra'], source_table['dec'], unit=(u.deg, u.deg), frame='icrs')
    galex = Table({key:np.full_like(source_table['idx'],np.nan,float) for key in list(galex_zp_table['col'])+list(galex_zp_table['err_col'])})
    galex['idx'] = source_table['idx']
    galex_dr7 = Vizier.query_region(coords, radius=2*u.arcsec,catalog=['II/335/galex_ais'])
    if len(galex_dr7) == 0:
        galex_dr7 = Table({'idx':source_table['idx'],'nuv_mag':np.full_like(source_table['idx'], np.nan, dtype=float),
                           'nuv_magerr':np.full_like(source_table['idx'], np.nan, dtype=float),'fuv_mag':np.full_like(source_table['idx'], np.nan, dtype=float),
                           'fuv_magerr':np.full_like(source_table['idx'], np.nan, dtype=float)})
    else:
        galex_dr7 = galex_dr7[0]
        ## rename to the MAST column names used in galex_zp_table
        galex_dr7['NUVmag'].name = 'nuv_mag'
        galex_dr7['e_NUVmag'].name = 'nuv_magerr'
        galex_dr7['FUVmag'].name = 'fuv_mag'
        galex_dr7['e_FUVmag'].name = 'fuv_magerr'
        ## '_q' is the 1-based index of the query position; translate it into the source 'idx'
        galex_dr7['_q'] = galex_dr7['_q'] - 1
        galex_dr7['_q'] = source_table['idx'][galex_dr7['_q']]
        galex_dr7['_q'].name = 'idx'

    for i in range(len(coords)):
        mast = Catalogs.query_region(coords[i], radius=2*u.arcsec,catalog='Galex')
        dr7_rows = np.where(galex_dr7['idx'] == galex[i]['idx'])[0]
        for col in galex_zp_table['col']:
            err_col = galex_zp_table[galex_zp_table['col']==col]['err_col'][0]
            f0 = galex_zp_table[galex_zp_table['col']==col]['f0'][0]
            m0 = galex_zp_table[galex_zp_table['col']==col]['m0'][0]
            mag = mag_err = None
            if len(mast) > 0 and not np.ma.is_masked(mast[col][0]):
                mag, mag_err = mast[col][0], mast[err_col][0]
            ## prefer GALEX DR7 whenever it has a valid value, also when MAST returned nothing
            if len(dr7_rows) > 0:
                dr7 = galex_dr7[dr7_rows[0]]
                if not np.ma.is_masked(dr7[col]) and not np.isnan(dr7[col]):
                    mag, mag_err = dr7[col], dr7[err_col]
            if mag is None: ## no GALEX data for this source in this band
                continue
            galex[i][col] = f0 * 10**(-0.4*(mag-m0))
            galex[i][err_col] = galex[i][col] * 0.4 * np.log(10) * mag_err ## mag error -> flux error

    for col in galex_zp_table['col']:
        band = galex_zp_table[galex_zp_table['col']==col]['band'][0]
        err_col = galex_zp_table[galex_zp_table['col']==col]['err_col'][0]
        galex[col].unit = u.erg / u.s / u.cm**2 / u.AA
        galex[col].name = band
        galex[err_col].unit = u.erg / u.s / u.cm**2 / u.AA
        galex[err_col].name = band + '_err'
    return galex
        
def _all_flux_bands():
    """Flux column names produced by the catalog getters (Gaia, WISE, 2MASS, synthetic, GALEX)."""
    return (list(gaia_zp_table['band']) + list(wise_zp_table['band']) + list(twomass_zp_table['band'])
            + list(synt_zp_table['band']) + list(galex_zp_table['band']))

def _join_catalog_photometry(tbl):
    """Left-join Gaia, Gaia-synthetic, GALEX, 2MASS and WISE photometry onto `tbl`, in that order."""
    getters = [(get_gaia_photometry, 'source_id'),
               (get_synthetic_photometry, 'source_id'),
               (get_galex_photometry, 'idx'),
               (get_twomass_photometry, 'idx'),
               (get_wise_photometry, 'idx')]
    for getter, key in getters:
        ## each getter is queried with the table as joined so far
        tbl = table.join(tbl, getter(tbl), keys=key, join_type='left',metadata_conflicts='silent')
    return tbl

def _apply_snr_error_floor(tbl, max_snr=10, frac=0.1):
    """Floor flux errors in place: where a band's SNR exceeds `max_snr`, set its error to `frac` * flux.

    This is a minimal error to account for model uncertainties. Masked or NaN SNRs
    fail the comparison and are left unchanged.
    """
    for col in _all_flux_bands():
        snr = tbl[col] / tbl[col + '_err']
        for i in range(len(tbl)):
            if snr[i] > max_snr:
                tbl[i][col + '_err'] = frac * tbl[i][col]

def get_photometry(source_table):
    """Collect Gaia, Gaia-synthetic, GALEX, 2MASS and WISE photometry for a table of sources.

    Parameters
    ----------
    source_table : Table with 'source_id', 'ra', 'dec', 'parallax', 'parallax_error',
        '[Fe/H]' and 'Av' columns. If it has no 'idx' column, one (0..N-1) is added to
        the caller's table in place (callers may read it afterwards).

    Returns
    -------
    Table sorted by 'idx' with the columns above plus '<band>' / '<band>_err' fluxes in
    erg/s/cm^2/AA. Errors of points with SNR > 10 are set to 10% of the flux.
    """
    if 'idx' not in source_table.colnames:
        source_table['idx'] = np.arange(len(source_table))
    tbl = source_table['idx','source_id','ra','dec','parallax','parallax_error','[Fe/H]','Av']
    tbl = _join_catalog_photometry(tbl)
    _apply_snr_error_floor(tbl)
    Table.sort(tbl,'idx')
    return tbl

def get_photometry_single_source(source):
    """Photometry for one source, given a table whose first row has at least a Gaia DR3 'source_id'.

    Missing 'ra', 'dec', 'parallax', 'parallax_error' and 'Av' are filled from
    gaiadr3.gaia_source ('Av' from ag_gspphot); a missing 'idx' defaults to 0.

    Returns
    -------
    One-row Table with the same flux / error columns as get_photometry (no '[Fe/H]').
    """
    query = f'''SELECT source_id, ra, dec, parallax, parallax_error, ag_gspphot FROM gaiadr3.gaia_source WHERE source_id = {source['source_id'][0]}'''
    job = Gaia.launch_job(query)
    result = job.get_results()

    tbl = Table(source)
    if 'idx' not in source.colnames:
        tbl['idx'] = 0
    ## fill any missing astrometry / extinction from the Gaia archive row.
    ## NOTE(review): ag_gspphot is Gaia's G-band extinction (A_G), not A_V -- confirm intended
    for col, gaia_col in [('ra', 'ra'), ('dec', 'dec'), ('parallax', 'parallax'),
                          ('parallax_error', 'parallax_error'), ('Av', 'ag_gspphot')]:
        if col not in source.colnames:
            tbl[col] = result[gaia_col][0]

    tbl = tbl['idx','source_id','ra','dec','parallax','parallax_error','Av']
    ## duplicate the row so every catalog query sees two identical targets; presumably this
    ## avoids single-target edge cases in the archive / VizieR queries -- NOTE(review): confirm
    ## before removing. Only the first row is returned.
    tbl.add_row(tbl[0])
    tbl = _join_catalog_photometry(tbl)
    _apply_snr_error_floor(tbl)
    return Table(tbl[0])

In [4]:
## model and fitting functions

## 'Obs.Flux_<band>' names for the photometric systems below (2MASS, ACS, APASS, DECam,
## DENIS, Gaia, GALEX, Johnson, PS1, SDSS, UKIDSS, VISTA, WISE).
## NOTE(review): not referenced by the functions in this cell -- confirm it is still needed
filter_list = ['Obs.Flux_2MASS.H', 'Obs.Flux_2MASS.J', 'Obs.Flux_2MASS.Ks',
                'Obs.Flux_ACS_WFC.F606W', 'Obs.Flux_ACS_WFC.F814W',
                  'Obs.Flux_APASS.B', 'Obs.Flux_APASS.V',
                    'Obs.Flux_DECam.Y', 'Obs.Flux_DECam.g', 'Obs.Flux_DECam.i', 'Obs.Flux_DECam.r', 'Obs.Flux_DECam.z',
                      'Obs.Flux_DENIS.I', 'Obs.Flux_DENIS.J', 'Obs.Flux_DENIS.Ks',
                        'Obs.Flux_GAIA3.G', 'Obs.Flux_GAIA3.Gbp', 'Obs.Flux_GAIA3.Grp', 'Obs.Flux_GAIA3.Grvs',
                          'Obs.Flux_GALEX.FUV', 'Obs.Flux_GALEX.NUV',
                            'Obs.Flux_Johnson.B', 'Obs.Flux_Johnson.I', 'Obs.Flux_Johnson.R', 'Obs.Flux_Johnson.U', 'Obs.Flux_Johnson.V',
                              'Obs.Flux_PS1.g', 'Obs.Flux_PS1.i', 'Obs.Flux_PS1.r', 'Obs.Flux_PS1.y', 'Obs.Flux_PS1.z',
                                'Obs.Flux_SDSS.g', 'Obs.Flux_SDSS.i', 'Obs.Flux_SDSS.r', 'Obs.Flux_SDSS.u', 'Obs.Flux_SDSS.z',
                                  'Obs.Flux_UKIDSS.K',
                                    'Obs.Flux_VISTA.H', 'Obs.Flux_VISTA.J', 'Obs.Flux_VISTA.Ks', 'Obs.Flux_VISTA.Y',
                                      'Obs.Flux_WISE.W1', 'Obs.Flux_WISE.W2', 'Obs.Flux_WISE.W3']

def get_distance(parallax,parallax_err):
    """Distance and its first-order uncertainty, both in meters, from a parallax in mas.

    d = 1000/parallax pc and sigma_d = (1000/parallax**2) * parallax_err pc;
    works element-wise on numpy arrays.
    """
    dist_pc = 1000 / parallax
    dist_err_pc = (1000 / parallax**2) * parallax_err
    meters_per_pc = pc.value
    return dist_pc * meters_per_pc, dist_err_pc * meters_per_pc

def blackbody_flux(wavelength, temperature):
    """Blackbody surface flux pi * B_lambda(T) in erg/s/cm^2/AA.

    Both `wavelength` and `temperature` must be astropy Quantities (any length /
    temperature units); `wavelength` may be an array.
    """
    lam = wavelength.cgs
    h_cgs, c_cgs, k_cgs = h.cgs, c.cgs, k_B.cgs
    prefactor = 2 * np.pi * h_cgs * c_cgs**2 / lam**5
    exponent = h_cgs * c_cgs / (lam * k_cgs * temperature)
    surface_flux = prefactor / (np.exp(exponent) - 1)
    return surface_flux.to(u.erg / u.s / u.cm**2 / u.AA)

## filter transmission curves already read from disk, keyed by absolute file path, so
## repeated model evaluations (e.g. inside curve_fit) read each filter file only once
_filter_curve_cache = {}

def _read_filter_curve(band):
    """Return (wavelength Quantity [AA], transmission ndarray) from ../data/VOSA/filters/<band>.dat."""
    filepath = os.path.abspath(os.path.join('..','data','VOSA','filters',band+'.dat'))
    if filepath not in _filter_curve_cache:
        filter_tbl = Table.read(filepath, format='ascii', names=['wavelength','transmission'])
        _filter_curve_cache[filepath] = (filter_tbl['wavelength'].data * u.AA, filter_tbl['transmission'].data)
    return _filter_curve_cache[filepath]

def blackbody_mod_table(teff,radius,parallax,parallax_err,Av):
    """Synthetic band fluxes of a reddened blackbody, one column per band of `bands_table`.

    Parameters
    ----------
    teff : temperature [K]; radius : [R_sun]; parallax, parallax_err : [mas]
        (only the parallax is used, through get_distance); Av : extinction handed to the
        project `extinction` helpers.

    Returns
    -------
    One-row Table of fluxes in erg/s/cm^2/AA: blackbody_flux * (R/d)^2, averaged over each
    filter curve and dimmed by 10^(-0.4 A_band).
    """
    bands = list(bands_table['band'])
    mod_tbl = Table(dict(zip(bands,[[float(1)] for i in range(len(bands))])))
    d, _ = get_distance(parallax,parallax_err) ## distance error is not needed here
    radius_m = radius * R_sun.value

    ## per-band extinction in magnitudes (second argument 0 and teff passed as in the
    ## project's extinction module -- NOTE(review): confirm their meaning there)
    AW1,AW2,AW3,AW4 = extinction.get_WISE_extinction(Av,0,teff)
    AJ,AH,AKs = extinction.get_2MASS_extinction(Av,0,teff)
    AU,AB,AV,AR,AI = extinction.get_Johnson_extinction(Av,0,teff)
    Au,Ag,Ar,Ai,Az = extinction.get_SDSS_extinction(Av,0,teff)
    AG,AGbp,AGrp = extinction.get_Gaia_extinction(Av,0,teff)
    AFUV,ANUV = extinction.get_Galex_extinction(Av,0,teff)

    ext_dict = {'2MASS.J':AJ,'2MASS.H':AH,'2MASS.Ks':AKs,'GALEX.FUV':AFUV,'GALEX.NUV':ANUV,'GAIA3.G':AG,'GAIA3.Gbp':AGbp,'GAIA3.Grp':AGrp,
                'WISE.W1':AW1,'WISE.W2':AW2,'WISE.W3':AW3,'WISE.W4':AW4,'Johnson.U':AU,'Johnson.B':AB,'Johnson.V':AV,'Johnson.R':AR,'Johnson.I':AI,
                'SDSS.u':Au,'SDSS.g':Ag,'SDSS.r':Ar,'SDSS.i':Ai,'SDSS.z':Az}

    for b in bands:
        wavelength, transmission = _read_filter_curve(b)
        flux = blackbody_flux(wavelength, teff * u.K)
        ## plain transmission-weighted mean over the tabulated filter points (no d-lambda or
        ## photon-counting weighting) -- NOTE(review): confirm this matches the zero points
        flux = np.dot(flux, transmission) / np.sum(transmission)
        flux = flux * (radius_m/d)**2
        mod_tbl[0][b] = flux.value
        mod_tbl[b].unit = flux.unit
        mod_tbl[b] = mod_tbl[b] * 10**(-0.4 * ext_dict[b])

    return mod_tbl
        
def get_chi2(obs_tbl_dered, mod_tbl, no_uv = False):
    """Reduced chi^2 of a model against observed photometry.

    Parameters
    ----------
    obs_tbl_dered : table with '<band>' and '<band>_err' columns (fluxes and errors).
    mod_tbl : table with one '<band>' column per model band.
    no_uv : if True, bands whose name starts with 'GALEX' are ignored.

    Returns
    -------
    sum((obs - mod)^2 / err^2) / ndof, where ndof counts the bands used. Bands missing
    from the observations, masked, or with NaN flux/error are skipped. Returns NaN when
    no band is usable (instead of dividing by zero).
    """
    chi2 = 0
    ndof = 0
    for col in mod_tbl.colnames:
        if no_uv and col.startswith('GALEX'):
            continue
        if col not in obs_tbl_dered.colnames:
            continue
        obs_col, err_col = obs_tbl_dered[col], obs_tbl_dered[col + '_err']
        if np.ma.is_masked(obs_col) or np.ma.is_masked(err_col):
            continue
        obs, err, mod = np.broadcast_arrays(np.asarray(obs_col, dtype=float),
                                            np.asarray(err_col, dtype=float),
                                            np.asarray(mod_tbl[col], dtype=float))
        ## the photometry getters store missing points as NaN rather than masking them
        valid = np.isfinite(obs) & np.isfinite(err)
        if not valid.any():
            continue
        chi2 += np.sum((obs[valid] - mod[valid])**2 / err[valid]**2)
        ndof += 1
    return chi2 / ndof if ndof > 0 else np.nan

def plot_obs_sed(sources,idx,model=None,title='SED',no_uv = True, no_IR = False):
    """Plot the observed SED (lambda * f_lambda) of one source, optionally with a blackbody model.

    Parameters
    ----------
    sources : Table accepted by get_photometry.
    idx : positional row of the photometry table (which is sorted by 'idx') to plot.
    model : optional (teff, radius, teff_err, radius_err, chi2) tuple. If given, the
        reddened blackbody from blackbody_mod_table is overplotted with a fit-summary box.
    title : figure title.
    no_uv, no_IR : highlight points below 2500 AA / above 10000 AA as not used in the fit.

    Returns
    -------
    matplotlib Figure.
    """
    warnings.simplefilter('ignore', UserWarning)
    obs_tbl = get_photometry(sources)
    bnds = list(bands_table['band'])
    wl = np.array(list(bands_table['lambda_eff']))
    flux = np.array([obs_tbl[idx][bnd] for bnd in bnds])
    flux_err = np.array([obs_tbl[idx][bnd + '_err'] for bnd in bnds])

    if model is not None:
        ## constructing and plotting model SED
        fig, ax = plt.subplots(figsize=(1.5*figwidth_single, figwidth_single), tight_layout=True,dpi=300)
        teff, radius, teff_err,radius_err, chi2 = model
        mod_tbl = blackbody_mod_table(teff,radius,obs_tbl[idx]['parallax'],obs_tbl[idx]['parallax_error'],obs_tbl[idx]['Av'])
        mod_flux = wl * np.array([mod_tbl[bnd][0] for bnd in bnds])
        ax.plot(wl, mod_flux,color='Navy',ls='--',lw=.5,label=rf'BB {int(teff)}K/{radius:.2f} $R_\odot$',marker='d',mec='RoyalBlue',mfc='RoyalBlue',markersize=2, zorder = 0)

        ## fit-summary box: raw strings keep the LaTeX backslashes (\chi, \pm, \odot are not
        ## valid Python escapes); lines are joined with real newlines
        textstr = '\n'.join([rf'$\chi^2/ndof$= {chi2:.2f} ',
                             r'$T_{eff,1}=$' + rf'{int(teff)}$\pm${int(teff_err)} K',
                             r'$R_1=$' + rf'{radius:.2f}$\pm${radius_err:.2f} $R_\odot$'])
        props = dict(boxstyle='round', facecolor='Silver', alpha=0.9)
        ax.text(0.35, 0.5, textstr, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', bbox=props, color='k')
    else:
        fig,ax = plt.subplots(figsize=(figwidth_single, figwidth_single),dpi=300)

    ## plotting observed SED with error bars, mark UV and IR points if not used in fit
    flux = wl * flux
    ax.errorbar(wl,flux,yerr = wl * flux_err,fmt='.',label='observed',ecolor='k',elinewidth=1,mec='Crimson',mfc='Crimson',capsize=2,markersize=2, zorder = 5)

    no_fit_cut = ((wl < 2500) * no_uv) | ((wl>10000) * bool(no_IR))
    if np.count_nonzero(~np.isnan(flux[no_fit_cut])) > 0:
        ax.scatter(wl[no_fit_cut],flux[no_fit_cut],marker='.',c='Gold',label='No fit',s=3,zorder = 10)

    ## aesthetics
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_title(title,fontsize=8)
    ax.set_xlabel(r'$\lambda$ $(\AA)$',fontsize=10)
    ax.set_ylabel(r'$\lambda f_\lambda$ ($erg/s/cm^2/\AA$)',fontsize=10)
    ax.tick_params(axis='both', which='major', labelsize=8)
    ax.grid()
    ax.legend(fontsize=10,loc='lower center')
    fig.show()
    return fig

def fit_table(sources, init = (6000,1) ,teff_bounds = (3000,10000), r_bounds = (0.4,3.0), no_uv = True, no_IR = ()):
    """Fit a single reddened blackbody (teff, radius) to the photometry of every source.

    Parameters
    ----------
    sources : Table accepted by get_photometry (gets an 'idx' column if it has none).
    init : (teff [K], radius [R_sun]) initial guess, or a list with one such tuple per source.
    teff_bounds, r_bounds : (lower, upper) fit bounds, or a list with one tuple per source.
        Per-source lists are matched to the photometry rows, which are sorted by 'idx'.
    no_uv : if True, GALEX bands are excluded from every fit.
    no_IR : collection of source 'idx' values whose 2MASS and WISE bands are excluded.

    Returns
    -------
    Table, one row per source: teff1, teff1_err, r1, r1_err (errors are the square roots of
    the curve_fit covariance diagonal), chi2/ndof and no_IR (1.0 if IR bands were excluded).

    Raises
    ------
    ValueError if a per-source list of init / bounds does not have len(sources) entries.
    """
    n_src = len(sources)

    def _per_source(arg, name):
        ## a single tuple applies to every source; otherwise require one entry per source
        if isinstance(arg, tuple):
            return [arg] * n_src
        if len(arg) != n_src:
            raise ValueError(f'{name} must be a tuple or a list of tuples of length len(sources)')
        return arg

    ## validate before the (slow) catalog queries
    init = _per_source(init, 'init')
    teff_bounds = _per_source(teff_bounds, 'teff_bounds')
    r_bounds = _per_source(r_bounds, 'r_bounds')

    obs_tbl = get_photometry(sources)
    # obs_tbl_dered = deredden_obs_table(obs_tbl)
    results = Table(np.zeros((n_src,6)), names= ['teff1','teff1_err','r1','r1_err','chi2/ndof','no_IR'])

    all_bands = list(bands_table['band'])
    lambda_eff = dict(zip(bands_table['band'], bands_table['lambda_eff']))

    for i, (obs_row, p0, t_bnds, rad_bnds) in enumerate(tqdm(zip(obs_tbl, init, teff_bounds, r_bounds), total=n_src)):
        ## read idx from the photometry row itself: obs_tbl is sorted by 'idx', which need
        ## not match the row order of `sources`
        idx = obs_row['idx']
        drop_ir = idx in no_IR
        bnds = [b for b in all_bands
                if not (no_uv and b.startswith('GALEX'))
                and not (drop_ir and (b.startswith('2MASS') or b.startswith('WISE')))
                and not (np.ma.is_masked(obs_row[b]) or np.isnan(obs_row[b]))]
        wl = np.array([lambda_eff[b] for b in bnds])
        flux = np.array([obs_row[b] for b in bnds])
        flux_err = np.array([obs_row[b + '_err'] for b in bnds])

        def bb_fun(wl, teff, radius):
            ## model fluxes of the bands kept for this source (wl only satisfies curve_fit's signature)
            mod_tbl = blackbody_mod_table(teff,radius,obs_row['parallax'],obs_row['parallax_error'],obs_row['Av'])
            return np.array([mod_tbl[b][0] for b in bnds])

        vals, covm = curve_fit(bb_fun, wl, flux, p0 = list(p0), bounds = ([t_bnds[0],rad_bnds[0]],[t_bnds[1],rad_bnds[1]]), sigma = flux_err, absolute_sigma = True)
        results[i]['teff1'] = vals[0]
        results[i]['teff1_err'] = covm[0,0]**(1/2)
        results[i]['r1'] = vals[1]
        results[i]['r1_err'] = covm[1,1]**(1/2)
        results[i]['chi2/ndof'] = np.sum((flux - bb_fun(wl,vals[0],vals[1]))**2 / flux_err**2) / (len(flux) - len(vals))
        results[i]['no_IR'] = drop_ir
    return results

# SED models and UV excess
* Yarin: run only the first two cells in this section.

In [6]:
## SED model functions


## map filter ids of the form '<facility>/<filter>' to the band names used in bands_table
## (presumably the VOSA/SVO filter ids -- TODO confirm)
filter_band_dict = {'GALEX/GALEX.FUV':'GALEX.FUV','GALEX/GALEX.NUV':'GALEX.NUV','GAIA/GAIA3.G':'GAIA3.G','GAIA/GAIA3.Gbp':'GAIA3.Gbp','GAIA/GAIA3.Grp':'GAIA3.Grp',
    'Generic/Johnson.U':'Johnson.U','Generic/Johnson.B':'Johnson.B','Generic/Johnson.V':'Johnson.V','Generic/Johnson.R':'Johnson.R','Generic/Johnson.I':'Johnson.I'
        ,'SLOAN/SDSS.u':'SDSS.u','SLOAN/SDSS.g':'SDSS.g','SLOAN/SDSS.r':'SDSS.r','SLOAN/SDSS.i':'SDSS.i','SLOAN/SDSS.z':'SDSS.z',
        '2MASS/2MASS.J':'2MASS.J','2MASS/2MASS.H':'2MASS.H','2MASS/2MASS.Ks':'2MASS.Ks',
        'WISE/WISE.W1':'WISE.W1','WISE/WISE.W2':'WISE.W2','WISE/WISE.W3':'WISE.W3'}

## parse each model-grid file once and derive the grid nodes from the parsed arrays
_kurucz_raw = np.genfromtxt('../data/kurucz_sed.dat',names=True,dtype=None)
_co_da_raw = np.genfromtxt('../data/WD_models/CO_DA.dat',names=True,dtype=None)

## sorted unique grid nodes (Kurucz: teff, logg, metallicity; Bedard CO DA: Teff, log_g)
kurucz_tgrid = np.unique(_kurucz_raw['teff'])
kurucz_ggrid = np.unique(_kurucz_raw['logg'])
kurucz_metagrid = np.unique(_kurucz_raw['meta'])

bedard_tgrid = np.unique(_co_da_raw['Teff'])
bedard_ggrid = np.unique(_co_da_raw['log_g'])

## full grids as Tables (Table() copies the parsed arrays)
kurucz = Table(_kurucz_raw)
co_da = Table(_co_da_raw)
co_db = Table(np.genfromtxt('../data/WD_models/CO_DB.dat',names=True,dtype=None))
he_da = Table(np.genfromtxt('../data/WD_models/He_wd.dat',names=True,dtype=None))
one_da = Table(np.genfromtxt('../data/WD_models/ONe_DA.dat',names=True,dtype=None))
one_db = Table(np.genfromtxt('../data/WD_models/ONe_DB.dat',names=True,dtype=None))

def get_ms_sed(teff,m1,r1,meta,parallax):
    """Interpolate the Kurucz grid to get the model SED of the main-sequence star.

    Parameters
    ----------
    teff : float
        Effective temperature [K].
    m1 : float
        Mass [M_sun]; combined with ``r1`` to compute log g.
    r1 : float
        Radius [R_sun].
    meta : float
        Metallicity on the Kurucz 'meta' axis (callers pass the '[Fe/H]' column).
    parallax : float
        Parallax [mas].

    Returns
    -------
    astropy.table.Table
        One row, one column per ``bands_table['wd_band']``. Each column holds the interpolated
        model flux times the dilution factor (R / d)^2, with unit erg/s/cm^2/AA. All columns
        are NaN when (teff, logg, meta) falls outside the grid.
    """
    model_table = Table(data=[[np.nan] for band in bands_table['wd_band']],names=bands_table['wd_band'])
    logg = np.log10((G.cgs * m1 * M_sun.cgs / (r1 *R_sun.cgs)**2).value)

    def _bracket(grid, x):
        ## Index j of the grid cell [grid[j-1], grid[j]] containing x. Returns None when x is NaN,
        ## off the grid, or the grid has a single node. Unlike a bare searchsorted, a value exactly
        ## on the lowest node is accepted (it used to yield an all-NaN SED).
        if len(grid) < 2 or not (grid[0] <= x <= grid[-1]):
            return None
        return max(int(np.searchsorted(grid, x)), 1)

    j_t, j_g, j_meta = _bracket(kurucz_tgrid,teff), _bracket(kurucz_ggrid,logg), _bracket(kurucz_metagrid,meta)
    if j_t is None or j_g is None or j_meta is None:
        return model_table

    ## keep only the 2x2x2 cube of grid nodes around the target point
    kur = kurucz[np.isin(kurucz['teff'], kurucz_tgrid[j_t-1:j_t+1])
                 & np.isin(kurucz['logg'], kurucz_ggrid[j_g-1:j_g+1])
                 & np.isin(kurucz['meta'], kurucz_metagrid[j_meta-1:j_meta+1])]

    bands = list(bands_table['wd_band'])
    ## one triangulation shared by all bands; each band column is interpolated independently
    interp = LinearNDInterpolator(np.array([kur['teff'],kur['logg'],kur['meta']]).T,
                                  np.column_stack([kur[band] for band in bands]))
    ## dilution factor (R / d)^2 with d = 1000 / parallax pc
    dilution = (r1 * R_sun.cgs * parallax / 1000 / pc.cgs)**2
    for band, flux in zip(bands, interp([teff,logg,meta])[0]):
        model_table[band] = flux * dilution
        model_table[band].unit = u.erg / u.s / u.cm**2 / u.AA
    return model_table

def logg_from_MR_relation(teff,m,core='CO',atm='H'):
    """Surface gravity of a white dwarf from the Bedard cooling tables (mass-radius relation).

    Parameters
    ----------
    teff : float
        Effective temperature [K].
    m : float
        WD mass [M_sun].
    core : {'He', 'CO', 'ONe'}
        Core composition. The He-core table is used regardless of ``atm``.
    atm : {'H', 'He'}
        Atmosphere composition (used for CO and ONe cores).

    Returns
    -------
    float or None
        log g in cm/s^2 (log10). It is NaN when (teff, m) lies outside the table's convex hull.
        None is returned, after a printed message, for an unsupported ``core``/``atm``.
    """
    if core == 'He':
        wd = he_da
    elif core in ('CO', 'ONe'):
        if atm not in ('H', 'He'):
            print('atm either H or He')
            return None
        if core == 'CO':
            wd = co_da if atm == 'H' else co_db
        else:
            wd = one_da if atm == 'H' else one_db
    else:
        ## an unknown core used to fall through and raise UnboundLocalError
        print('core either He, CO or ONe')
        return None
    ## the tables are only read here, so no defensive copy is needed
    interp = LinearNDInterpolator(np.array([wd['Teff'],wd['Mass']]).T,wd['log_g'])
    return interp([teff,m])[0]

def get_wd_sed(teff,m,parallax,core='CO',atm='H'):
    """Model SED of a white dwarf from the Bedard synthetic photometry tables.

    log g follows from the mass-radius relation of the requested core/atmosphere. (Teff, log g)
    is then interpolated on the CO-core table of the requested atmosphere (CO_DA for H, CO_DB
    for He) to get absolute magnitudes. These are moved to the parallax distance and converted
    to flux densities with the ``bands_table`` zero points: f = f0 * 10**(-0.4 * (m - m0)).

    Parameters
    ----------
    teff : float
        WD effective temperature [K].
    m : float
        WD mass [M_sun].
    parallax : float
        Parallax [mas].
    core : {'He', 'CO', 'ONe'}
        Core composition; only affects the mass-radius relation.
    atm : {'H', 'He'}
        Atmosphere composition.

    Returns
    -------
    astropy.table.Table or None
        One row, one column per ``bands_table['wd_band']``, in erg/s/cm^2/AA. Values are NaN
        where the point is off the model grid or the band cannot be interpolated. None is
        returned, after a printed message, for an unsupported ``atm`` or ``core``.
    """
    ## use MR relation specific to core and atmosphere type
    logg = logg_from_MR_relation(teff,m,core,atm)
    model_table = Table(data=[[np.nan] for band in bands_table['wd_band']],names=bands_table['wd_band'])

    ## then interpolate Teff, Logg on the CO model to get SED (tables are only read: no copy needed)
    if atm == 'H':
        model = co_da
    elif atm == 'He':
        model = co_db
    else:
        print('atm either H or He')
        return None
    if logg is None:
        ## unsupported core; logg_from_MR_relation reported the reason
        return None

    ## locate the grid cell around (teff, logg); off-grid or NaN points give an all-NaN SED
    j_t, j_g = np.searchsorted(bedard_tgrid,teff), np.searchsorted(bedard_ggrid,logg)
    if j_t == 0 or j_g == 0 or j_t == len(bedard_tgrid) or j_g == len(bedard_ggrid):
        return model_table

    ## Restricting to the surrounding 2x2 cell makes the triangulation cheap. Keep the full
    ## table when the cell has fewer than 3 nodes in this table (too few to triangulate).
    cut = np.isin(model['Teff'],[bedard_tgrid[j_t-1],bedard_tgrid[j_t]]) & np.isin(model['log_g'],[bedard_ggrid[j_g-1],bedard_ggrid[j_g]])
    if np.count_nonzero(cut) > 2:
        model = model[cut]

    for wd_band,f0,m0 in zip(bands_table['wd_band'],bands_table['f0'],bands_table['m0']):
        try:
            interp = LinearNDInterpolator(np.array([model['Teff'],model['log_g']]).T,model[wd_band])
            absmag = interp([teff,logg])[0]
            apmag = absmag + 5 * np.log10(1000/parallax) - 5  ## distance modulus, d = 1000/parallax pc
            model_table[wd_band] = f0 * 10**(-0.4*(apmag - m0))
        except Exception:
            ## Best effort: a band the model cannot provide stays NaN. This used to be a bare
            ## `except:`, which also swallowed KeyboardInterrupt during long notebook loops.
            model_table[wd_band] = np.nan
        model_table[wd_band].unit = u.erg / u.s / u.cm**2 / u.AA
    return model_table

def get_wd_koester_sed(teff,m,parallax,core='CO',atm='H'):
    """Model SED of a white dwarf from the Koester grid in ../data/koester_sed.dat.

    log g follows from the mass-radius relation of the requested core/atmosphere. The radius
    R = sqrt(G M / g) sets the dilution factor (R / d)^2 applied to the interpolated model flux.

    Parameters
    ----------
    teff : float
        WD effective temperature [K].
    m : float
        WD mass [M_sun].
    parallax : float
        Parallax [mas].
    core : {'He', 'CO', 'ONe'}
        Core composition (mass-radius relation only).
    atm : {'H', 'He'}
        Atmosphere composition (mass-radius relation only).

    Returns
    -------
    astropy.table.Table or None
        One row, one column per ``bands_table['wd_band']``, in erg/s/cm^2/AA. Values are NaN
        outside the grid. None is returned for an unsupported ``core``/``atm``; this used to
        raise a TypeError on ``10**None``.
    """
    ## use MR relation specific to core and atmosphere type
    logg = logg_from_MR_relation(teff,m,core,atm)
    if logg is None:
        return None
    r = np.sqrt(G.cgs * m * M_sun.cgs / (10**logg) / (u.cm/u.s**2))/R_sun.cgs ## in R_sun

    ## then interpolate Teff, Logg on the koester model to get SED
    model = np.genfromtxt('../data/koester_sed.dat',names=True,dtype=None)
    model_table = Table(data=[[np.nan] for band in bands_table['wd_band']],names=bands_table['wd_band'])
    bands = list(bands_table['wd_band'])
    ## one triangulation shared by all bands; each band column is interpolated independently
    interp = LinearNDInterpolator(np.array([model['teff'],model['logg']]).T,
                                  np.column_stack([model[band] for band in bands]))
    dilution = (r * R_sun.cgs * parallax / 1000 / pc.cgs)**2
    for band, flux in zip(bands, interp([teff,logg])[0]):
        model_table[band] = flux * dilution
        model_table[band].unit = u.erg / u.s / u.cm**2 / u.AA
    return model_table

def redden_model_table(mod_tbl,teff,av):
    """Apply interstellar extinction to a one-row model SED table, in place.

    For every band in ``bands_table``, the flux in column ``wd_band`` (row 0) is multiplied by
    10**(-0.4 * A_band). A_band comes from the ``extinction`` helpers called as (av, 0, teff).
    ``teff`` is presumably used for band-averaged coefficients (TODO confirm). The mutated
    table is also returned.
    """
    a_w1, a_w2, a_w3, a_w4 = extinction.get_WISE_extinction(av,0,teff)
    a_j, a_h, a_ks = extinction.get_2MASS_extinction(av,0,teff)
    a_U, a_B, a_V, a_R, a_I = extinction.get_Johnson_extinction(av,0,teff)
    a_u, a_g, a_r, a_i, a_z = extinction.get_SDSS_extinction(av,0,teff)
    a_G, a_Gbp, a_Grp = extinction.get_Gaia_extinction(av,0,teff)
    a_fuv, a_nuv = extinction.get_Galex_extinction(av,0,teff)

    band_extinction = {
        'GALEX.FUV': a_fuv, 'GALEX.NUV': a_nuv,
        'GAIA3.G': a_G, 'GAIA3.Gbp': a_Gbp, 'GAIA3.Grp': a_Grp,
        'Johnson.U': a_U, 'Johnson.B': a_B, 'Johnson.V': a_V, 'Johnson.R': a_R, 'Johnson.I': a_I,
        'SDSS.u': a_u, 'SDSS.g': a_g, 'SDSS.r': a_r, 'SDSS.i': a_i, 'SDSS.z': a_z,
        '2MASS.J': a_j, '2MASS.H': a_h, '2MASS.Ks': a_ks,
        'WISE.W1': a_w1, 'WISE.W2': a_w2, 'WISE.W3': a_w3, 'WISE.W4': a_w4,
    }
    for survey_band, flux_col in zip(bands_table['band'], bands_table['wd_band']):
        mod_tbl[flux_col][0] = mod_tbl[flux_col][0] * 10**(-0.4 * band_extinction[survey_band])
    return mod_tbl

def _wd_cooling_model(core, atm):
    """Return the Bedard cooling table for a core/atmosphere, or None (after a message) if unsupported.

    The He-core table is returned regardless of ``atm``. The module-level tables are used
    instead of re-reading the files on every call.
    """
    if core == 'He':
        return he_da
    if core not in ('CO', 'ONe'):
        print('core either He, CO or ONe')
        return None
    if atm not in ('H', 'He'):
        print('atm either H or He')
        return None
    if core == 'CO':
        return co_da if atm == 'H' else co_db
    return one_da if atm == 'H' else one_db

def get_cooling_age(teff,m,core='CO',atm='H'): # WD cooling age for given teff, mass, core, atm
    """Cooling age of a white dwarf, interpolated in (Teff, Mass) on the Bedard tables.

    Parameters
    ----------
    teff : float
        Effective temperature [K].
    m : float
        WD mass [M_sun].
    core : {'He', 'CO', 'ONe'}
    atm : {'H', 'He'}
        Ignored for He cores.

    Returns
    -------
    float or None
        Age in the units of the table's 'Age' column (presumably yr; TODO confirm). It is NaN
        outside the table's convex hull. None is returned for an unsupported ``core``/``atm``.
    """
    model = _wd_cooling_model(core, atm)
    if model is None:
        return None
    interp = LinearNDInterpolator(np.array([model['Teff'],model['Mass']]).T,model['Age'])
    age = interp([teff,m])[0]
    return np.nan if np.isnan(age) else age

def get_cooling_temp(age,m,core='CO',atm='H'): # WD cooling temp for given age, mass, core, atm
    """Effective temperature of a white dwarf of mass ``m`` after cooling for ``age``.

    A first guess is obtained by interpolating Teff directly in (Age, Mass). It is refined by
    evaluating ``get_cooling_age`` on 20 temperatures within +-10000 K of the guess and
    inverting that relation with ``np.interp``.

    Returns
    -------
    float or None
        Teff [K]. NaN when no trial temperature has a valid age; this happens when the guess
        itself is off the grid. None is returned for an unsupported ``core``/``atm``.
    """
    model = _wd_cooling_model(core, atm)
    if model is None:
        return None
    interp = LinearNDInterpolator(np.array([model['Age'],model['Mass']]).T,model['Teff'])
    teff_init = interp([age,m])[0]
    teff_vec = np.linspace(teff_init+10000,teff_init-10000,20)
    age_vec = np.array([get_cooling_age(teff,m,core,atm) for teff in teff_vec], dtype=float)
    ## np.interp requires finite, increasing abscissae. Drop NaN ages first: they come from
    ## off-grid trial Teff values, e.g. negative ones when teff_init < 10000 K. Then sort by age.
    ok = np.isfinite(age_vec)
    if not ok.any():
        return np.nan
    order = np.argsort(age_vec[ok], kind='stable')
    return np.interp(age, age_vec[ok][order], teff_vec[ok][order])


In [7]:
## plotting functions

def plot_kurucz_vs_obs(sources,idx,plot=True,save=False,ax=None,no_IR=False):
    """Plot the observed SED of one candidate with its best-fit (reddened) Kurucz MS model.

    Parameters
    ----------
    sources : astropy.table.Table
        Candidate table with columns 'idx', 'teff1', 'teff1_err', 'r1', 'r1_err', 'm1',
        'parallax', '[Fe/H]' and 'Av'.
    idx : int
        ``sources['idx']`` value of the candidate to plot.
    plot : bool
        If False, close the figure after drawing. This only happens when this function created it.
    save : bool
        Save to ../img/kurucz/kurucz_{idx}.png. When ``ax`` is given, the figure containing it is saved.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. When None, a new figure with a log wavelength axis is created.
    no_IR : bool
        Also mark the points redward of 10000 AA in black.

    Notes
    -----
    The legend's reduced chi^2 uses the points redward of 3000 AA with non-NaN flux and error,
    with 2 degrees of freedom removed. Points blueward of 3000 AA are drawn in black as
    excluded from the fit.
    """
    warnings.simplefilter('ignore', UserWarning)
    i = np.where(sources['idx'] == idx)[0][0]
    obs_tbl = get_photometry_single_source(Table(sources[i]))

    ## observed SED as lambda * f_lambda
    bnds = list(bands_table['band'])
    wl = np.array(list(bands_table['lambda_eff']))
    flux = np.array([obs_tbl[0][bnd] for bnd in bnds]) * wl
    flux_err = np.array([obs_tbl[0][bnd + '_err'] for bnd in bnds]) * wl

    teff1 = sources[i]['teff1']
    teff1_err = sources[i]['teff1_err']
    r1 = sources[i]['r1']
    r1_err = sources[i]['r1_err']
    m = sources[i]['m1']
    parallax = sources[i]['parallax']
    meta = sources[i]['[Fe/H]']
    av = sources[i]['Av']
    ms = get_ms_sed(teff1, m, r1, meta, parallax)
    ms = redden_model_table(ms,teff1,av)

    flux_bst = np.array([ms[bnd][0] for bnd in bands_table['wd_band']]) * wl

    make_xlabel = ax is None
    if make_xlabel:
        fig, ax = plt.subplots(figsize=(1.5*figwidth_single, figwidth_single), tight_layout=True,dpi=400)
    else:
        ## `fig` used to be undefined on this path, so save=True / plot=False raised NameError
        fig = ax.figure

    ## displaying fit results
    cut = (wl>3000) & ~np.isnan(flux) & ~np.isnan(flux_err)
    chi2 = np.sum((flux[cut] - flux_bst[cut])**2 / flux_err[cut]**2) / (len(flux[cut]) - 2)
    ## number of decimals that keeps one significant digit of the error
    sig_dig_t = -int(np.floor(np.log10(teff1_err)))
    sig_dig_r = -int(np.floor(np.log10(r1_err)))

    ## backslashes escaped explicitly: '\c', '\p', '\o' are invalid escapes in non-raw strings
    label = (f'Best-fit MS model, $\\chi^2/ndof$= {chi2:.2f} \n' + r'$T_{\text{eff},1}=$'
             + f'{int(np.round(teff1,sig_dig_t))}$\\pm${int(np.round(teff1_err,sig_dig_t))} K, '
             + '$R_1=$' + f'{np.round(r1,sig_dig_r)}$\\pm${np.round(r1_err,sig_dig_r)} $R_\\odot$')
    ax.plot(wl,flux_bst,color='Navy',ls='--',lw=1,label=label,zorder=5)

    ## displaying observed SED
    ax.scatter(wl,flux,marker='.',color='Crimson',s=60,zorder = 10,label=f'Photometry for candidate {idx}')
    if len(flux[wl<3000]) > 0:
        ax.scatter(wl[wl<3000],flux[wl<3000],marker='.',color='k',s=60,zorder = 10,label='Excluded from fit')
    if no_IR:
        ax.scatter(wl[wl>10000],flux[wl>10000],marker='.',color='k',s=60,zorder = 10) ## testing no_IR

    if make_xlabel:
        ax.set_xscale('log')
        ax.tick_params(axis='x', which='major')
        ax.set_xlabel(r'Wavelength $(\AA)$',fontsize=10)
        fig.show()
    ax.legend(fontsize=9,loc='lower center',frameon=False)
    ax.set_yscale('log')
    ax.set_ylabel(r'$\lambda f_\lambda $ (erg$\,$s$^{-1}\,$cm$^{-2}$)',labelpad=-1.2,fontsize=12)
    ax.tick_params(axis='y', which='major', labelsize=10)
    ax.minorticks_on()
    if save:
        fig.savefig(f'../img/kurucz/kurucz_{idx}.png')
    if not plot and make_xlabel:
        ## never close a figure owned by the caller
        plt.close(fig)

def plot_mswd_vs_obs(sources,idx,ms_params,wd_params,ax=None,plot=True,save=False): ## gets ms_params = (teff1,r1) and wd_params = (teff2,core,atm)
    """Plot the observed SED of one candidate with an MS (Kurucz) + WD (Bedard) model.

    Parameters
    ----------
    sources : astropy.table.Table
        Candidate table with columns 'idx', 'm1', 'm2', 'parallax', '[Fe/H]' and 'Av'.
    idx : int
        ``sources['idx']`` value of the candidate to plot.
    ms_params : tuple
        (teff1 [K], r1 [R_sun]) of the main-sequence star.
    wd_params : tuple
        (teff2 [K], core, atm) of the white dwarf; see ``get_wd_sed``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. When None, a new figure with a log wavelength axis is created.
    plot : bool
        If False, close the figure after drawing. This only happens when this function created it.
    save : bool
        Save to ../img/binary_fit/binary_fit_{idx}.png (the figure containing ``ax`` if given).
    """
    warnings.simplefilter('ignore', UserWarning)
    i = np.where(sources['idx'] == idx)[0][0]
    obs_tbl = get_photometry_single_source(Table(sources[i]))

    ## observed SED as lambda * f_lambda
    bnds = list(bands_table['band'])
    wl = np.array(list(bands_table['lambda_eff']))
    flux = np.array([obs_tbl[0][bnd] for bnd in bnds]) * wl

    m1,m2 = sources[i]['m1'],sources[i]['m2']
    parallax, meta, av = sources[i]['parallax'], sources[i]['[Fe/H]'], sources[i]['Av']
    t1, r1 = ms_params
    t2, core, atm = wd_params

    ms = get_ms_sed(t1, m1, r1, meta, parallax)
    ms = redden_model_table(ms,t1,av)
    wd = get_wd_sed(t2, m2, parallax, core, atm)
    wd = redden_model_table(wd,t2,av)
    wd_model_flux = np.array([wd[bnd][0] for bnd in bands_table['wd_band']]) * wl
    ms_model_flux = np.array([ms[bnd][0] for bnd in bands_table['wd_band']]) * wl
    model_flux = wd_model_flux + ms_model_flux

    own_fig = ax is None
    if own_fig:
        fig, ax = plt.subplots(figsize=(1.5*figwidth_single, figwidth_single), tight_layout=True,dpi=300)
        ax.set_xlabel(r'wavelength $(\AA)$')
        ax.set_xscale('log')
    else:
        ## `fig` used to be undefined on this path, so save=True / plot=False raised NameError
        fig = ax.figure

    ax.plot(wl,model_flux,color='k',ls='-',lw=1,label=f'Best-fit MS + WD to candidate {idx}',zorder=5)
    ax.plot(wl,wd_model_flux,color='Navy',ls='dotted',lw=1,label=f'WD model, {int(np.round(t2,-2))}K, {core} core, {atm} atm',zorder=0)
    ax.plot(wl,ms_model_flux,color='Navy',ls='dashed',lw=1,label=f'MS model, {int(np.round(t1,-2))}K, {r1:.2f} $R_\\odot$',zorder=0)
    ax.scatter(wl,flux,marker='.',color='Crimson',s=60,zorder = 10)
    ax.set_ylabel(r'$\lambda f_\lambda $ (erg$\,$s$^{-1}\,$cm$^{-2}$)',labelpad=0,fontsize=12)
    ax.set_yscale('log')
    ax.legend(fontsize=9,loc='lower left',frameon=False)
    ax.tick_params(axis='both', which='major', labelsize=10)
    ## save before closing (the original closed the figure first)
    if save:
        fig.savefig(f'../img/binary_fit/binary_fit_{idx}.png')
    if not plot and own_fig:
        plt.close(fig)

def plot_msbb_vs_obs(sources,idx,ms_params,bb_params,ax=None,plot=True,save=False): ## gets ms_params = (teff1,r1) and bb_params = (teff2,r2)
    """Plot the observed SED of one candidate with an MS (Kurucz) + blackbody model.

    Parameters
    ----------
    sources : astropy.table.Table
        Candidate table with columns 'idx', 'm1', 'parallax', '[Fe/H]' and 'Av'.
    idx : int
        ``sources['idx']`` value of the candidate to plot.
    ms_params : tuple
        (teff1 [K], r1 [R_sun]) of the main-sequence star.
    bb_params : tuple
        (teff2 [K], r2 [R_sun]) of the blackbody, passed to ``blackbody_mod_table``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. When None, a new figure with a log wavelength axis is created.
    plot : bool
        If False, close the figure after drawing. This only happens when this function created it.
    save : bool
        Save to ../img/binary_fit/msbb_{idx}.png (the figure containing ``ax`` if given).
    """
    warnings.simplefilter('ignore', UserWarning)
    i = np.where(sources['idx'] == idx)[0][0]
    obs_tbl = get_photometry_single_source(Table(sources[i]))

    ## observed SED as lambda * f_lambda
    bnds = list(bands_table['band'])
    wl = np.array(list(bands_table['lambda_eff']))
    flux = np.array([obs_tbl[0][bnd] for bnd in bnds]) * wl

    m1 = sources[i]['m1']
    parallax, meta, av = sources[i]['parallax'], sources[i]['[Fe/H]'], sources[i]['Av']
    t1, r1 = ms_params
    t2, r2 = bb_params

    ms = get_ms_sed(t1, m1, r1, meta, parallax)
    ms = redden_model_table(ms,t1,av)
    bb = blackbody_mod_table(t2,r2,parallax,0,av)
    bb_model_flux = np.array([bb[bnd][0] for bnd in bands_table['band']]) * wl
    ms_model_flux = np.array([ms[bnd][0] for bnd in bands_table['wd_band']]) * wl
    model_flux = bb_model_flux + ms_model_flux

    own_fig = ax is None
    if own_fig:
        fig, ax = plt.subplots(figsize=(1.5*figwidth_single, figwidth_single), tight_layout=True,dpi=300)
        ax.set_xlabel(r'wavelength $(\AA)$')
        ax.set_xscale('log')
    else:
        ## `fig` used to be undefined on this path, so save=True / plot=False raised NameError
        fig = ax.figure

    ax.plot(wl,model_flux,color='k',ls='-',lw=0.5,label=f'Best-fit MS + BB to candidate {idx}',zorder=5)
    ax.plot(wl,bb_model_flux,color='Navy',ls='dotted',lw=0.5,label=f'BB model, {int(t2)}K, {r2:.2f} $R_\\odot$',zorder=0)
    ax.plot(wl,ms_model_flux,color='Navy',ls='dashed',lw=0.5,label=f'MS model, {int(t1)}K, {r1:.2f} $R_\\odot$',zorder=0)
    ax.scatter(wl,flux,marker='.',color='Crimson',s=30,zorder = 10)
    ax.set_ylabel(r'$\lambda f_\lambda $ (erg$\,$s$^{-1}\,$cm$^{-2}$)',labelpad=0,fontsize=12)
    ax.set_yscale('log')
    ax.set_ylim(5e-18)
    ax.legend(fontsize=9,loc='lower left',frameon=False)
    ax.tick_params(axis='both', which='major', labelsize=10)
    ## save before closing (the original closed the figure first)
    if save:
        fig.savefig(f'../img/binary_fit/msbb_{idx}.png')
    if not plot and own_fig:
        plt.close(fig)

def plot_kurucz_residuals(sources,idx,ax1=None,ax2=None,no_IR=False):
    """Plot fractional residuals (obs - model) / model of the Kurucz MS fit for one candidate.

    Parameters
    ----------
    sources : table
        Candidate table with columns 'idx', 'parallax', '[Fe/H]', 'Av', 'm1', 'teff1', 'r1'.
    idx : int
        ``sources['idx']`` value of the candidate.
    ax1 : matplotlib Axes, optional
        Main residual axes. When None, a new figure with a log wavelength axis is created and shown.
    ax2 : matplotlib Axes, optional
        Upper axes of a broken-axis layout. When given, ``ax1`` is limited to residuals up to
        0.22 and ``ax2`` shows the window 0.8-1.2x the largest residual. Both get the diagonal
        break marks.
    no_IR : bool
        Also mark points redward of 10000 AA in black.

    Points blueward of 3000 AA are overplotted in black, matching the points marked
    'Excluded from fit' in ``plot_kurucz_vs_obs``. Error bars are flux_err / model.
    """
    j = np.where(sources['idx'] == idx)[0][0]
    obs_tbl = get_photometry_single_source(Table(sources[j]))
    bnds = list(bands_table['band'])
    wl = np.array(list(bands_table['lambda_eff']))
    # observed lambda * f_lambda and its error
    flux = np.array([obs_tbl[0][bnd] for bnd in bnds]) * wl
    flux_err = np.array([obs_tbl[0][bnd + '_err'] for bnd in bnds]) * wl
    
    # fixed parameters
    parallax = sources[j]['parallax']
    meta = sources[j]['[Fe/H]']
    av = sources[j]['Av']
    m1 = sources[j]['m1']

    # fitted parameters
    t1 = sources[j]['teff1']
    r1 = sources[j]['r1']

    ms = get_ms_sed(t1,m1,r1,meta,parallax)
    ms = redden_model_table(ms,t1,av)
    model_flux = np.array([ms[bnd][0] for bnd in bands_table['wd_band']]) * wl
    residuals = (flux - model_flux) / model_flux
    res_error = flux_err / model_flux

    make_xlabel = False
    if ax1 is None:
        fig,ax1 = plt.subplots(figsize=(7.2,2.4), tight_layout=True,dpi=400)
        make_xlabel = True

    ax1.errorbar(wl,residuals,yerr = res_error,fmt='.',label='observed',ecolor='k',elinewidth=1,mec='Crimson',mfc='Crimson',capsize=0,markersize=7)
    # points excluded from the fit (blueward of 3000 AA) drawn in black
    ax1.scatter(wl[wl<3000],residuals[wl<3000],marker='.',color='k',s=40,zorder = 10)
    if no_IR:
        ax1.scatter(wl[wl>10000],residuals[wl>10000],marker='.',color='k',s=40,zorder = 10) ## testing no_IR
    ax1.set_ylabel(r'Residuals',fontsize=10,labelpad=0)
    ax1.tick_params(axis='y', which='major', labelsize=8)
    ax1.axhline(0,color='Navy',ls='--',lw=1,label='model')

    # broken axis
    if ax2 is not None:
        # lower panel: bulk of the residuals; upper panel: window around the largest residual
        ax1.set_ylim(0.9 * np.nanmin(residuals)-0.15,0.22)
        ax2.set_ylim(0.8 * np.nanmax(residuals),1.2 * np.nanmax(residuals))
        ax2.errorbar(wl,residuals,yerr = res_error,fmt='.',label='observed',ecolor='k',elinewidth=1,mec='Crimson',mfc='Crimson',capsize=0,markersize=7)
        ax2.scatter(wl[wl<3000],residuals[wl<3000],marker='.',color='k',s=40,zorder = 10)
        if no_IR:
            ax2.scatter(wl[wl>10000],residuals[wl>10000],marker='.',color='k',s=40,zorder = 10) ## testing no_IR
        # broken axis aesthetics
        ax1.spines.top.set_visible(False)
        ax2.spines.bottom.set_visible(False)
        ax1.xaxis.tick_bottom()
        ax2.xaxis.set_visible(False)
        ax1.tick_params(labeltop=False)
        ax2.tick_params(axis='y', which='major', labelsize=8)
        d = .05  # proportion of vertical to horizontal extent of the slanted line
        kwargs = dict(marker=[(-1, -d), (1, d)], markersize=7,
                    linestyle="none", color='k', mec='k', mew=1, clip_on=False)
        # diagonal break marks at the bottom of ax2 and the top of ax1 (axes coordinates)
        ax2.plot([0, 1], [0, 0], transform=ax2.transAxes, **kwargs)
        ax1.plot([0, 1], [1, 1], transform=ax1.transAxes, **kwargs)
        # leading spaces shift the label up so it sits between the two panels
        ax1.set_ylabel(r'        Residuals',fontsize=12,labelpad=0)
        

    if make_xlabel:
        ax1.set_xscale('log')
        ax1.set_xlabel(r'Wavelength $(\AA)$',fontsize=10)
        fig.show()

def plot_mswd_residuals(sources,idx,ms_params,wd_params,ax1):
    """Draw fractional residuals (obs - model) / model of an MS + WD model on ``ax1``.

    ``ms_params`` is (teff1, r1) and ``wd_params`` is (teff2, core, atm). Both components are
    reddened with the candidate's 'Av' and summed. Error bars are flux_err / model.
    """
    row = sources[np.where(sources['idx'] == idx)[0][0]]
    phot = get_photometry_single_source(Table(row))[0]
    wl = np.array(list(bands_table['lambda_eff']))
    obs = np.array([phot[b] for b in bands_table['band']]) * wl
    obs_err = np.array([phot[b + '_err'] for b in bands_table['band']]) * wl

    # fixed parameters
    parallax, meta, av = row['parallax'], row['[Fe/H]'], row['Av']
    m1, m2 = row['m1'], row['m2']

    t1, r1 = ms_params
    t2, core, atm = wd_params

    ms = redden_model_table(get_ms_sed(t1,m1,r1,meta,parallax), t1, av)
    wd = redden_model_table(get_wd_sed(t2,m2,parallax,core,atm), t2, av)
    wd_flux = np.array([wd[b][0] for b in bands_table['wd_band']]) * wl
    ms_flux = np.array([ms[b][0] for b in bands_table['wd_band']]) * wl
    model = wd_flux + ms_flux
    frac_res = (obs - model) / model
    frac_err = obs_err / model

    ax1.errorbar(wl,frac_res,yerr = frac_err,fmt='.',label='observed',ecolor='k',elinewidth=1,mec='Crimson',mfc='Crimson',capsize=0,markersize=7)
    ax1.set_ylabel(r'Residuals',fontsize=12)
    ax1.tick_params(axis='both', which='major', labelsize=8)
    ax1.axhline(0,color='k',ls='-',lw=1,label='model')

def plot_msbb_residuals(sources,idx,ms_params,bb_params,ax1):
    """Draw fractional residuals (obs - model) / model of an MS + blackbody model on ``ax1``.

    ``ms_params`` is (teff1, r1) and ``bb_params`` is (teff2, r2). The blackbody comes from
    ``blackbody_mod_table(teff2, r2, parallax, 0, Av)``. Error bars are flux_err / model.
    """
    row = sources[np.where(sources['idx'] == idx)[0][0]]
    phot = get_photometry_single_source(Table(row))[0]
    wl = np.array(list(bands_table['lambda_eff']))
    obs = np.array([phot[b] for b in bands_table['band']]) * wl
    obs_err = np.array([phot[b + '_err'] for b in bands_table['band']]) * wl

    # fixed parameters (m2 is read for parity with the MS+WD version; the blackbody does not use it)
    parallax, meta, av = row['parallax'], row['[Fe/H]'], row['Av']
    m1, _m2 = row['m1'], row['m2']

    t1, r1 = ms_params
    t2, r2 = bb_params

    ms = redden_model_table(get_ms_sed(t1,m1,r1,meta,parallax), t1, av)
    bb = blackbody_mod_table(t2,r2,parallax,0,av)
    bb_flux = np.array([bb[b][0] for b in bands_table['band']]) * wl
    ms_flux = np.array([ms[b][0] for b in bands_table['wd_band']]) * wl
    model = bb_flux + ms_flux
    frac_res = (obs - model) / model
    frac_err = obs_err / model

    ax1.set_yticks([-0.2,0,0.2])
    ax1.errorbar(wl,frac_res,yerr = frac_err,fmt='.',label='observed',ecolor='k',elinewidth=0.5,mec='Crimson',mfc='Crimson',capsize=0,markersize=5)
    ax1.set_ylabel(r'Residuals',fontsize=12)
    ax1.tick_params(axis='both', which='major', labelsize=8)
    ax1.axhline(0,color='k',ls='-',lw=0.5,label='model')

def plot_kurucz_with_residuals(sources,idx,plot=True,save=False,no_IR=False,broken=None):
    """Two-panel figure: Kurucz MS fit vs. photometry (top) and fractional residuals (bottom).

    Parameters
    ----------
    sources : astropy.table.Table
        Candidate table (see ``plot_kurucz_vs_obs``).
    idx : int
        ``sources['idx']`` value of the candidate.
    plot : bool
        If False, close the figure after drawing.
    save : bool
        Save to ../img/kurucz/kurucz_{idx}.png.
    no_IR : bool
        Also mark points redward of 10000 AA in black (passed to both panels).
    broken : bool, optional
        Use a broken residual axis: a lower panel for residuals up to 0.22 and an upper panel
        around the largest residual (see ``plot_kurucz_residuals``). The default None keeps
        the previous behavior: broken only for candidates 53, 68, 133, 175, 194, 249, 281
        and 283.
    """
    fig = plt.figure(figsize=(1.5*figwidth_single,figwidth_single),tight_layout=True, dpi=400)

    if broken is None:
        ## hand-picked candidates (presumably those with an outlying residual)
        broken = idx in [53,68,175,194,249,281,283,133]

    if not broken:
        # regular axis: SED panel above a single residual panel
        gs = fig.add_gridspec(2, 1, height_ratios=[3, 1], hspace=0)
        ax = gs.subplots(sharex=True)
        plot_kurucz_vs_obs(sources,idx,plot=True,save=False,ax=ax[0],no_IR=no_IR)
        plot_kurucz_residuals(sources,idx,ax1=ax[1],no_IR=no_IR)
        ax[1].set_xscale('log')
        ax[1].set_xlabel(r'Wavelength $(\AA)$',fontsize=12)
        ax[1].tick_params(axis='x', which='major', labelsize=10)
    else:
        # broken axis: SED panel, then the upper (ax[1]) and lower (ax[2]) residual panels
        gs = fig.add_gridspec(3, 1, height_ratios=[3, 1/2, 1/2])
        gs.update(hspace=0.03)
        ax = gs.subplots(sharex=True)
        ## explicit [left, bottom, width, height] positions override the gridspec layout
        ax[0].set_position([0.125,0.31,0.8,0.60])
        ax[1].set_position([0.125,0.23,0.8,0.08])
        ax[2].set_position([0.125,0.12,0.8,0.08])
        plot_kurucz_vs_obs(sources,idx,plot=True,save=False,ax=ax[0],no_IR=no_IR)
        plot_kurucz_residuals(sources,idx,ax1=ax[2],ax2=ax[1],no_IR=no_IR)
        ax[0].xaxis.set_visible(False)
        ax[0].minorticks_on()
        ax[2].set_xscale('log')
        ax[2].set_xlabel(r'Wavelength $(\AA)$',fontsize=12,labelpad=-2)
        ax[2].tick_params(axis='x', which='major', labelsize=10)
    if save:
        fig.savefig(f'../img/kurucz/kurucz_{idx}.png')
    if not plot:
        plt.close(fig)

def plot_mswd_with_residuals(sources,idx,ms_params,wd_params,plot=True,save=False):
    """Two-panel figure: MS + WD model vs. photometry (top) and fractional residuals (bottom).

    ``ms_params`` is (teff1, r1) and ``wd_params`` is (teff2, core, atm). The figure is saved
    to ../img/binary_fit/binary_fit_{idx}.png when ``save``, and closed when ``plot`` is False.
    """
    fig = plt.figure(figsize=(1.5*figwidth_single,figwidth_single),tight_layout=True, dpi=400)

    # regular axis: 3:1 height ratio, shared wavelength axis, no gap
    sed_ax, res_ax = fig.add_gridspec(2, 1, height_ratios=[3, 1], hspace=0).subplots(sharex=True)
    plot_mswd_vs_obs(sources,idx,ms_params,wd_params,ax=sed_ax,plot=True,save=False)
    plot_mswd_residuals(sources,idx,ms_params,wd_params,ax1=res_ax)
    res_ax.set_xscale('log')
    res_ax.set_xlabel(r'Wavelength $(\AA)$',fontsize=12)
    res_ax.tick_params(axis='x', which='major', labelsize=10)

    if save:
        fig.savefig(f'../img/binary_fit/binary_fit_{idx}.png')
    if not plot:
        plt.close(fig)

def plot_msbb_with_residuals(sources,idx,ms_params,bb_params,plot=True,save=False):
    """Two-panel figure: MS + blackbody model vs. photometry (top) and fractional residuals (bottom).

    ``ms_params`` is (teff1, r1) and ``bb_params`` is (teff2, r2). The figure is saved to
    ../img/binary_fit/msbb_{idx}.png when ``save``, and closed when ``plot`` is False.
    """
    fig = plt.figure(figsize=(1.5*figwidth_single,figwidth_single),tight_layout=True, dpi=400)

    # regular axis: 3:1 height ratio, shared wavelength axis, no gap
    sed_ax, res_ax = fig.add_gridspec(2, 1, height_ratios=[3, 1], hspace=0).subplots(sharex=True)
    plot_msbb_vs_obs(sources,idx,ms_params,bb_params,ax=sed_ax,plot=True,save=False)
    plot_msbb_residuals(sources,idx,ms_params,bb_params,ax1=res_ax)
    res_ax.set_xscale('log')
    res_ax.set_xlabel(r'Wavelength $(\AA)$',fontsize=12)
    res_ax.tick_params(axis='x', which='major', labelsize=10)

    if save:
        fig.savefig(f'../img/binary_fit/msbb_{idx}.png')
    if not plot:
        plt.close(fig)

In [8]:
## Load the candidate table; uncomment the call below to regenerate the Kurucz-fit figure of candidate `idx`.
sources = Table.read('../table_C.fits')
idx = 175
# plot_kurucz_with_residuals(sources,idx,plot=False,save=True,no_IR=False)

In [9]:
## Load the candidates and the MCMC chain of one candidate. Pick the WD core from its mass
## (He below 0.45, CO below 1.1, ONe otherwise); the atmosphere is always H here.
sources = Table.read('../table_C.fits', format='fits')
# for idx in sources['idx']:
#     plot_kurucz_with_residuals(sources, idx,plot = False, save= True)
# plot_kurucz_with_residuals(sources,283,plot = False, save= True)
idx = 283
# plot_kurucz_with_residuals(sources,idx,plot = False, save= True)
chain = Table.read(f'../data/chains/chains_{idx}.csv')
m2 = sources[sources['idx'] == idx]['m2'][0]
core = 'He' if m2 < 0.45 else 'CO' if m2 < 1.1 else 'ONe'
atm = 'H'
# plot_mswd_with_residuals(sources,idx,(np.median(chain['teff1']),np.median(chain['r1'])),(np.median(chain['teff2']),core,atm),plot = False, save= True)

In [10]:
# plot_mswd_with_residuals(sources,175, (6312,1.41),(11480,'CO','H'),plot = False, save= True)

In [11]:
## UV excess functions

def _galex_zero_point(band):
    """(f0, m0) zero point of a GALEX band ('GALEX.NUV' or 'GALEX.FUV') from ``galex_zp_table``."""
    row = galex_zp_table[galex_zp_table['band'] == band]
    return row['f0'].data[0], row['m0'].data[0]

def obs_nuv_flux(mag):
    """NUV magnitude -> flux density, f = f0 * 10**(-0.4 * (mag - m0))."""
    f0, m0 = _galex_zero_point('GALEX.NUV')
    return f0 * 10**(-0.4 * (mag - m0))

def obs_fuv_flux(mag):
    """FUV magnitude -> flux density, f = f0 * 10**(-0.4 * (mag - m0))."""
    f0, m0 = _galex_zero_point('GALEX.FUV')
    return f0 * 10**(-0.4 * (mag - m0))

def nuv_flux_to_mag(flux):
    """NUV flux density -> magnitude (inverse of ``obs_nuv_flux``)."""
    f0, m0 = _galex_zero_point('GALEX.NUV')
    return m0 - 2.5 * np.log10(flux / f0)

def fuv_flux_to_mag(flux):
    """FUV flux density -> magnitude (inverse of ``obs_fuv_flux``)."""
    f0, m0 = _galex_zero_point('GALEX.FUV')
    return m0 - 2.5 * np.log10(flux / f0)

def _primary_galex_lims(col,t_lo,t,t_hi,m_lo,m,m_hi,r_lo,r,r_hi,parallax,meta,av):
    """Reddened Kurucz flux of the primary in GALEX column ``col`` ('FUV' or 'NUV').

    It is evaluated at the high, nominal and low parameter sets and returned as (lo, nominal, hi).
    """
    seds = {key: get_ms_sed(tt,mm,rr,meta,parallax)
            for key, (tt,mm,rr) in (('hi',(t_hi,m_hi,r_hi)), ('mid',(t,m,r)), ('lo',(t_lo,m_lo,r_lo)))}
    reddened = {}
    for key, tt in (('lo',t_lo), ('mid',t), ('hi',t_hi)):
        a_fuv, a_nuv = extinction.get_Galex_extinction(av,0,tt)
        a_band = a_fuv if col == 'FUV' else a_nuv
        reddened[key] = seds[key][col][0] * 10**(-0.4 * a_band)
    return reddened['lo'], reddened['mid'], reddened['hi']

def get_primary_nuv_lims(t_lo,t,t_hi,m_lo,m,m_hi,r_lo,r,r_hi,parallax,meta,av):
    """Reddened NUV flux of the primary as (lo, nominal, hi) over the given Teff/M/R bounds."""
    ## get the NUV flux range of the primary star
    return _primary_galex_lims('NUV',t_lo,t,t_hi,m_lo,m,m_hi,r_lo,r,r_hi,parallax,meta,av)

def get_primary_fuv_lims(t_lo,t,t_hi,m_lo,m,m_hi,r_lo,r,r_hi,parallax,meta,av):
    """Reddened FUV flux of the primary as (lo, nominal, hi) over the given Teff/M/R bounds."""
    return _primary_galex_lims('FUV',t_lo,t,t_hi,m_lo,m,m_hi,r_lo,r,r_hi,parallax,meta,av)

def _wd_trial_teff_grid(core, atm, he_grid):
    """Trial WD Teff grid for ``core``; raises ValueError for an unsupported core/atm.

    He cores use ``he_grid``; CO and ONe cores use 3000-50000 K
    (for hotter WDs, np.linspace(50000,150000,30) was used instead).
    """
    if atm not in ('H', 'He'):
        raise ValueError(f"atm must be 'H' or 'He', got {atm!r}")
    if core == 'He':
        return he_grid
    if core in ('CO', 'ONe'):
        return np.linspace(3000,50000,30)
    ## an unknown core used to leave the grid undefined and raise NameError
    raise ValueError(f"core must be 'He', 'CO' or 'ONe', got {core!r}")

def _solve_wd_teff(excess, wd_teff_grid, wd_flux, m2, core, atm):
    """Teff at which ``wd_flux(Teff)`` equals ``excess``.

    The root is found by linearly interpolating wd_flux/excess - 1 against Teff. When no
    solution can be interpolated, sign(excess) is returned: +1 for an excess, -1 for a deficit.
    """
    ## Keep trial temperatures whose mass-radius log g lies in [7, 9]. dtype=float maps a None
    ## from logg_from_MR_relation to NaN, so it is dropped as well instead of crashing the comparison.
    logg_grid = np.array([logg_from_MR_relation(t,m2,core,atm) for t in wd_teff_grid], dtype=float)
    wd_teff_grid = wd_teff_grid[(logg_grid >= 7) & (logg_grid <= 9) & ~np.isnan(logg_grid)]
    try:
        fvec = np.array([wd_flux(t)/excess - 1 for t in wd_teff_grid])
        good = ~np.isnan(fvec)
        return interp1d(fvec[good], wd_teff_grid[good], kind='linear')(0)
    except Exception:
        ## No root on the grid, or a model failure: report only the sign. This used to be a
        ## bare `except:`, which also swallowed KeyboardInterrupt.
        return np.sign(excess)

def get_companion_RT_lims_fuv(obs_mag,fuv_lo,fuv,fuv_hi,m2,parallax,av,core='CO',atm='H'):
    """WD companion Teff needed to explain the observed GALEX FUV flux on top of the primary.

    Parameters
    ----------
    obs_mag : float
        Observed FUV magnitude.
    fuv_lo, fuv, fuv_hi : float
        Reddened FUV flux of the primary (see ``get_primary_fuv_lims``). Only the nominal
        ``fuv`` is used; the bounds are kept for interface compatibility.
    m2 : float
        WD mass [M_sun].
    parallax : float
        Parallax [mas].
    av : float
        Extinction A_V.
    core : {'He', 'CO', 'ONe'}
    atm : {'H', 'He'}

    Returns
    -------
    0-d ndarray or float
        Teff [K] of the WD whose reddened FUV flux equals the excess. If no solution lies on
        the trial grid (5000-25000 K for He cores, 3000-50000 K otherwise), returns the sign
        of the excess (+1 / -1).

    Raises
    ------
    ValueError
        For an unsupported ``core`` or ``atm``.
    """
    wd_teff_grid = _wd_trial_teff_grid(core, atm, np.linspace(5000,25000,40))

    def wd_reddened_fuv(t): ## auxiliary: reddened WD model FUV flux at Teff t
        wd = get_wd_sed(t,m2,parallax,core,atm)
        flux = wd['FUV'][0]
        AFUV,_ = extinction.get_Galex_extinction(av,0,t)
        return flux * 10**(-0.4 * AFUV)

    excess = obs_fuv_flux(obs_mag) - fuv
    return _solve_wd_teff(excess, wd_teff_grid, wd_reddened_fuv, m2, core, atm)

def get_companion_RT_lims_nuv(obs_mag,nuv_lo,nuv,nuv_hi,m2,parallax,av,core='CO',atm='H'):
    """WD companion Teff needed to explain the observed GALEX NUV flux on top of the primary.

    Same as ``get_companion_RT_lims_fuv`` but for NUV. The He-core trial grid is
    5000-15000 K (50 points). Only the nominal ``nuv`` is used; ``nuv_lo``/``nuv_hi`` are kept
    for interface compatibility.

    Raises
    ------
    ValueError
        For an unsupported ``core`` or ``atm``.
    """
    wd_teff_grid = _wd_trial_teff_grid(core, atm, np.linspace(5000,15000,50))

    def wd_reddened_nuv(t): ## auxiliary: reddened WD model NUV flux at Teff t
        wd = get_wd_sed(t,m2,parallax,core,atm)
        flux = wd['NUV'][0]
        _,ANUV = extinction.get_Galex_extinction(av,0,t)
        return flux * 10**(-0.4 * ANUV)

    excess = obs_nuv_flux(obs_mag) - nuv
    return _solve_wd_teff(excess, wd_teff_grid, wd_reddened_nuv, m2, core, atm)

def assess_uv_excess(sources,idx):
    """Print the WD companion Teff needed to explain the GALEX UV flux of one candidate.

    The primary's reddened UV flux is evaluated at the fitted (teff1, m1, r1). It is also
    evaluated with Teff and R scaled by 0.9 / 1.1 and m1 -/+ m1_err. The WD core follows from
    m2: He below 0.45, CO below 1.1, ONe otherwise.

    FUV is used whenever it was observed; NUV is used only when FUV is missing. For each band,
    the companion Teff is solved for H and He atmospheres. It is printed together with the
    minimum WD Teff allowed by the system 'age', taken as the mean of the H and He cooling
    tracks of that core. Nothing is printed when neither band was observed.
    """
    i = np.where(sources['idx'] == idx)[0][0]
    src = sources[i]
    nuv_mag = src['nuv_mag']
    fuv_mag = src['fuv_mag']
    has_nuv = not np.ma.is_masked(nuv_mag)
    has_fuv = not np.ma.is_masked(fuv_mag)
    if not (has_nuv or has_fuv):
        return

    t_lo,t,t_hi = src['teff1'] * 0.9, src['teff1'], src['teff1'] * 1.1
    m_lo,m,m_hi = src['m1'] - src['m1_err'], src['m1'], src['m1'] + src['m1_err']
    r_lo,r,r_hi = src['r1'] * 0.9, src['r1'], src['r1'] * 1.1
    parallax = src['parallax']
    meta = src['[Fe/H]']
    age = src['age']
    av = src['Av']
    m2 = src['m2']

    if m2 < 0.45:
        core = 'He'
    elif m2 < 1.1:
        core = 'CO'
    else:
        core = 'ONe'

    ## Minimum WD Teff allowed by the system age, on the cooling tracks of the core chosen
    ## above. Previously the default CO tracks were always used, whatever the WD mass.
    tmin_H = get_cooling_temp(age,m2,core=core,atm='H')
    tmin_He = get_cooling_temp(age,m2,core=core,atm='He')

    if has_nuv and not has_fuv:
        nuv_lo,nuv,nuv_hi = get_primary_nuv_lims(t_lo,t,t_hi,m_lo,m,m_hi,r_lo,r,r_hi,parallax,meta,av)
        nuv_H_tarr = get_companion_RT_lims_nuv(nuv_mag,nuv_lo,nuv,nuv_hi,m2,parallax,av,core=core,atm='H')
        nuv_He_tarr = get_companion_RT_lims_nuv(nuv_mag,nuv_lo,nuv,nuv_hi,m2,parallax,av,core=core,atm='He')
        print(f'{core} WD {idx} NUV: {nuv_H_tarr:.0f} K (H), {nuv_He_tarr:.0f} K (He). Minimum: {(tmin_H + tmin_He)/2:.0f} K')
    if has_fuv:
        fuv_lo,fuv,fuv_hi = get_primary_fuv_lims(t_lo,t,t_hi,m_lo,m,m_hi,r_lo,r,r_hi,parallax,meta,av)
        fuv_H_tarr = get_companion_RT_lims_fuv(fuv_mag,fuv_lo,fuv,fuv_hi,m2,parallax,av,core=core,atm='H')
        fuv_He_tarr = get_companion_RT_lims_fuv(fuv_mag,fuv_lo,fuv,fuv_hi,m2,parallax,av,core=core,atm='He')
        print(f'{core} WD {idx} FUV: {fuv_H_tarr:.0f} K (H), {fuv_He_tarr:.0f} K (He). Minimum: {(tmin_H + tmin_He)/2:.0f} K')

def get_sdb_RT_lims_nuv(obs_mag,nuv_lo,nuv,nuv_hi,radius,parallax,av):
    """Find the blackbody radius whose reddened GALEX NUV flux matches the observed NUV excess.

    The temperature is fixed at 15000 K and the radius is scanned over 50 values
    in [0.001, 3]. For each grid point the ratio ``model_NUV / excess - 1`` is
    computed, where ``excess = observed flux - nuv``. Its zero is then located by
    linear interpolation.

    Parameters
    ----------
    obs_mag : observed GALEX NUV magnitude (converted with ``obs_nuv_flux``).
    nuv_lo, nuv, nuv_hi : primary NUV flux bracket; only ``nuv`` is used here.
    radius : accepted for interface compatibility; currently unused.
    parallax, av : passed to ``blackbody_mod_table``.

    Returns
    -------
    (t, r) : interpolated temperature and radius at the solution.
        If no solution can be interpolated (for example, the root lies outside the
        grid, which ``interp1d`` reports as a ValueError), both values are
        ``np.sign(excess)``. These are +1/-1 flags, not physical values.
    """
    teff_grid = 15000  # fixed trial temperature [K]
    r_grid = np.linspace(0.001,3,50)

    def sdb_reddened_nuv(t,r,parallax,av): ## auxiliary: NUV flux of a reddened blackbody
        sdb = blackbody_mod_table(teff=t, radius=r, parallax=parallax, parallax_err=0.1*parallax, Av=av)
        return sdb['GALEX.NUV'][0]

    obs_flux = obs_nuv_flux(obs_mag)
    excess = obs_flux - nuv  # NUV flux not accounted for by the nominal primary

    teff_grid,r_grid = np.meshgrid(teff_grid,r_grid)
    teff_grid = teff_grid.flatten()
    r_grid = r_grid.flatten()

    try:
        fvec = np.array([sdb_reddened_nuv(t,r,parallax,av)/excess - 1 for t,r in zip(teff_grid,r_grid)])
        valid = ~np.isnan(fvec)
        teff_grid, r_grid, fvec = teff_grid[valid], r_grid[valid], fvec[valid]
        t = interp1d(fvec,teff_grid,kind='linear')(0)
        r = interp1d(fvec,r_grid,kind='linear')(0)
    except Exception:
        # Best effort: flag the sign of the excess instead of a solution.
        # (A bare `except:` would also swallow KeyboardInterrupt during long runs.)
        t,r = np.sign(excess),np.sign(excess)
    return t,r

def get_sdb_RT_lims_fuv(obs_mag,fuv_lo,fuv,fuv_hi,radius,parallax,av):
    """Find the blackbody radius whose reddened GALEX FUV flux matches the observed FUV excess.

    The temperature is fixed at 15000 K and the radius is scanned over 50 values
    in [0.001, 3]. For each grid point the ratio ``model_FUV / excess - 1`` is
    computed, where ``excess = observed flux - fuv``. Its zero is then located by
    linear interpolation.

    Parameters
    ----------
    obs_mag : observed GALEX FUV magnitude (converted with ``obs_fuv_flux``).
    fuv_lo, fuv, fuv_hi : primary FUV flux bracket; only ``fuv`` is used here.
    radius : accepted for interface compatibility; currently unused.
    parallax, av : passed to ``blackbody_mod_table``.

    Returns
    -------
    (t, r) : interpolated temperature and radius at the solution.
        If no solution can be interpolated (for example, the root lies outside the
        grid, which ``interp1d`` reports as a ValueError), both values are
        ``np.sign(excess)``. These are +1/-1 flags, not physical values.
    """
    teff_grid = 15000  # fixed trial temperature [K]
    r_grid = np.linspace(0.001,3,50)

    def sdb_reddened_fuv(t,r,parallax,av): ## auxiliary: FUV flux of a reddened blackbody
        sdb = blackbody_mod_table(teff=t, radius=r, parallax=parallax, parallax_err=0.1*parallax, Av=av)
        return sdb['GALEX.FUV'][0]

    obs_flux = obs_fuv_flux(obs_mag)
    excess = obs_flux - fuv  # FUV flux not accounted for by the nominal primary

    teff_grid,r_grid = np.meshgrid(teff_grid,r_grid)
    teff_grid = teff_grid.flatten()
    r_grid = r_grid.flatten()

    try:
        fvec = np.array([sdb_reddened_fuv(t,r,parallax,av)/excess - 1 for t,r in zip(teff_grid,r_grid)])
        valid = ~np.isnan(fvec)
        teff_grid, r_grid, fvec = teff_grid[valid], r_grid[valid], fvec[valid]
        t = interp1d(fvec,teff_grid,kind='linear')(0)
        r = interp1d(fvec,r_grid,kind='linear')(0)
    except Exception:
        # Best effort: flag the sign of the excess instead of a solution.
        # (A bare `except:` would also swallow KeyboardInterrupt during long runs.)
        t,r = np.sign(excess),np.sign(excess)
    return t,r

def fit_excess_to_bb(sources,idx,radius):
    """Print the blackbody temperature/radius that would account for a source's GALEX UV excess.

    The primary's UV flux is bracketed by varying Teff1 and R1 by +-10% and M1 by
    +-m1_err. The excess is then matched with ``get_sdb_RT_lims_fuv`` (or the NUV
    variant). FUV is used whenever available. NUV is used only when FUV is
    missing. A source with no UV magnitude prints nothing.

    Parameters
    ----------
    sources : astropy Table with an 'idx' column plus the stellar columns read below.
    idx : value of ``sources['idx']`` selecting the row.
    radius : forwarded to the ``get_sdb_RT_lims_*`` helpers.
    """
    row = sources[np.where(sources['idx'] == idx)[0][0]]
    t_lo,t,t_hi = row['teff1'] * 0.9, row['teff1'], row['teff1'] * 1.1
    m_lo,m,m_hi = row['m1'] - row['m1_err'], row['m1'], row['m1'] + row['m1_err']
    r_lo,r,r_hi = row['r1'] * 0.9, row['r1'], row['r1'] * 1.1
    parallax = row['parallax']
    meta = row['[Fe/H]']
    av = row['Av']

    nuv_mag = row['nuv_mag']
    fuv_mag = row['fuv_mag']
    has_nuv = not np.ma.is_masked(nuv_mag)
    has_fuv = not np.ma.is_masked(fuv_mag)

    # NUV is only a fallback for sources without an FUV magnitude.
    if has_nuv and not has_fuv:
        nuv_lo,nuv,nuv_hi = get_primary_nuv_lims(t_lo,t,t_hi,m_lo,m,m_hi,r_lo,r,r_hi,parallax,meta,av)
        t_bb,r_bb = get_sdb_RT_lims_nuv(nuv_mag,nuv_lo,nuv,nuv_hi,radius,parallax,av)
        print(f'SDB {idx} NUV: {t_bb:.0f} K, {r_bb:.3f}'+ r'$R_\odot$')

    if has_fuv:
        fuv_lo,fuv,fuv_hi = get_primary_fuv_lims(t_lo,t,t_hi,m_lo,m,m_hi,r_lo,r,r_hi,parallax,meta,av)
        t_bb,r_bb = get_sdb_RT_lims_fuv(fuv_mag,fuv_lo,fuv,fuv_hi,radius,parallax,av)
        print(f'SDB {idx} FUV: {t_bb:.0f} K, {r_bb:.3f}'+ r'$R_\odot$')

def calculate_uv_excess(sources, idx):
    """Return the observed-minus-primary GALEX UV flux for one source.

    The primary's NUV/FUV flux is bracketed by varying Teff1 and R1 by +-10% and
    M1 by +-m1_err (via ``get_primary_nuv_lims`` / ``get_primary_fuv_lims``).
    The observed magnitudes are converted with ``obs_nuv_flux`` / ``obs_fuv_flux``.
    A missing magnitude presumably yields a masked excess; callers check the
    result with ``np.ma.is_masked``.

    Parameters
    ----------
    sources : astropy Table with an 'idx' column plus the stellar/UV columns read below.
    idx : value of ``sources['idx']`` selecting the row.

    Returns
    -------
    dict
        'nuv_excess', 'fuv_excess' : observed flux minus the nominal primary flux.
        'nuv_excess_lo', 'nuv_excess_hi', 'fuv_excess_lo', 'fuv_excess_hi' :
            the same difference using the upper / lower primary flux bound.
    """
    row = sources[np.where(sources['idx'] == idx)[0][0]]
    t_lo, t, t_hi = row['teff1'] * 0.9, row['teff1'], row['teff1'] * 1.1
    m_lo, m, m_hi = row['m1'] - row['m1_err'], row['m1'], row['m1'] + row['m1_err']
    r_lo, r, r_hi = row['r1'] * 0.9, row['r1'], row['r1'] * 1.1
    parallax = row['parallax']
    meta = row['[Fe/H]']
    av = row['Av']

    nuv_lo, nuv, nuv_hi = get_primary_nuv_lims(t_lo, t, t_hi, m_lo, m, m_hi, r_lo, r, r_hi, parallax, meta, av)
    fuv_lo, fuv, fuv_hi = get_primary_fuv_lims(t_lo, t, t_hi, m_lo, m, m_hi, r_lo, r, r_hi, parallax, meta, av)

    nuv_flux = obs_nuv_flux(row['nuv_mag'])
    fuv_flux = obs_fuv_flux(row['fuv_mag'])

    return {
        'nuv_excess': nuv_flux - nuv,
        'fuv_excess': fuv_flux - fuv,
        # a brighter primary (upper bound) leaves a smaller excess, hence the lo/hi swap
        'nuv_excess_lo': nuv_flux - nuv_hi,
        'nuv_excess_hi': nuv_flux - nuv_lo,
        'fuv_excess_lo': fuv_flux - fuv_hi,
        'fuv_excess_hi': fuv_flux - fuv_lo,
    }

In [12]:
# Source indices (values of sources['idx']) presumably showing a GALEX UV excess.
# `wd` and `sdb_or_flare` together partition `excess`; the names presumably record the
# favoured interpretation of each excess (WD companion vs. sdB/flare) -- confirm.
# fit_excess_to_bb(sources,180,radius=0.10)
sources = Table.read('../table_C.fits', format='fits')
excess = [1,15,29,54,57,58,69,83,131,132,133,139,146,147,152,156,157,217,232,235,237,239]
wd = [54,57,69,83,131,132,133,156,157,235,239]
sdb_or_flare = [1,15,29,58,139,146,147,152,217,232,237]
# Uncomment to print companion-temperature / blackbody estimates for the `wd` list:
# for idx in wd:
    # assess_uv_excess(sources,idx)
    # fit_excess_to_bb(sources,idx,radius=0.15)



# Fit VIS+IR photometry of the primary
* Yarin: work on these cells.
* First cell: fitting-function definitions. Run it to see which parameters are used.
* Second cell: unfinished code for you; it gets rough Gaia approximations for the extinction and metallicity.

In [13]:
def fit_src_kurucz(sources,idx,no_IR=False):
    """Least-squares fit of the primary's Teff1 and R1 to its optical (and optionally IR) photometry.

    The model is ``get_ms_sed`` reddened with ``redden_model_table``. M1, [Fe/H],
    parallax and Av are held at the table values. Bands at or below 3000 AA, and
    bands with NaN flux or error, are never fitted. ``no_IR=True`` additionally
    excludes bands at or above 10000 AA from the fit. The fit starts from
    (6000 K, 1.5) with bounds [3000, 15000] K and [0.4, 3.0].

    Returns
    -------
    (t1_fit, t1_err, r1_fit, r1_err, redchi2, N_ir, N_opt)
        redchi2 is evaluated on every usable band above 3000 AA (IR included even
        when ``no_IR`` is set), with N - 2 degrees of freedom. N_ir and N_opt count
        non-NaN fluxes above 10000 AA and between 3000 and 10000 AA, respectively.
    """
    src = sources[np.where(sources['idx'] == idx)[0][0]]
    phot = get_photometry_single_source(Table(src))[0]
    band_names = list(bands_table['band'])
    wl = np.array(list(bands_table['lambda_eff']))
    flux = np.array([phot[name] for name in band_names])
    flux_err = np.array([phot[name + '_err'] for name in band_names])

    # Parameters held fixed during the fit
    m1, meta, parallax, av = src['m1'], src['[Fe/H]'], src['parallax'], src['Av']

    def model_flux(x,t,r):
        sed = redden_model_table(get_ms_sed(t,m1,r,meta,parallax),t,av)
        predicted = np.array([sed[name][0] for name in bands_table['wd_band']])
        # keep only the bands whose effective wavelength appears in x
        return predicted[np.isin(list(bands_table['lambda_eff']),x)]

    usable = (wl>3000) & ~np.isnan(flux) & ~np.isnan(flux_err)  ## no UV
    x, y, y_err = wl[usable], flux[usable], flux_err[usable]
    in_fit = (usable & (wl<10000)) if no_IR else usable

    popt, pcov = curve_fit(model_flux, wl[in_fit], flux[in_fit], p0=[6000, 1.5],
                           bounds=[(3000,0.4),(15000,3.0)], sigma=flux_err[in_fit], absolute_sigma=True)
    t1_fit, r1_fit = popt
    t1_err, r1_err = np.sqrt(np.diag(pcov))
    chi2 = np.sum((y - model_flux(x,t1_fit,r1_fit))**2 / y_err**2)
    redchi2 = chi2 / (len(y) - 2)
    n_ir = np.sum(~np.isnan(flux[wl>10000]))
    n_opt = np.sum(~np.isnan(flux[(wl<10000) & (wl>3000)]))
    return t1_fit,t1_err,r1_fit,r1_err,redchi2,n_ir,n_opt

def fit_src_wd(sources,idx,no_uv=False,core=None,atm='H'):
    """Least-squares fit of a single white-dwarf SED (Teff2, M2) to one source's photometry.

    Parameters
    ----------
    sources : astropy Table with 'idx', 'parallax' and 'Av' columns.
    idx : value of ``sources['idx']`` selecting the row to fit.
    no_uv : if True, exclude bands with effective wavelength <= 3000 AA.
    core : WD core composition passed to ``get_wd_sed``. If None (default), it is
        chosen from the trial mass at every model evaluation: 'He' below 0.45,
        'CO' below 1.1, otherwise 'ONe' (the thresholds used in ``assess_uv_excess``).
    atm : WD atmosphere passed to ``get_wd_sed`` (default 'H').

    Returns
    -------
    (t2_fit, t2_err, m2_fit, m2_err, redchi2)
        Best-fit values, 1-sigma errors from the curve_fit covariance, and
        chi2 / (N - 2) over the fitted points.
    """
    def _core_for_mass(mass):
        # WD core composition from the mass thresholds used elsewhere in this notebook.
        if mass < 0.45:
            return 'He'
        if mass < 1.1:
            return 'CO'
        return 'ONe'

    j = np.where(sources['idx'] == idx)[0][0]
    obs_tbl = get_photometry_single_source(Table(sources[j]))
    bnds = list(bands_table['band'])
    wl = np.array(list(bands_table['lambda_eff']))
    flux = np.array([obs_tbl[0][bnd] for bnd in bnds])
    flux_err = np.array([obs_tbl[0][bnd + '_err'] for bnd in bnds])

    # fixed parameters
    parallax = sources[j]['parallax']
    av = sources[j]['Av']

    def model_flux(x,t,m2):
        # Reddened WD model fluxes restricted to the bands whose wavelengths are in x.
        # Previously `core` was derived from an undefined outer `m2` (a stale notebook global).
        wd_core = core if core is not None else _core_for_mass(m2)
        wd = get_wd_sed(t,m2,parallax,wd_core,atm)
        wd = redden_model_table(wd,t,av)  # same reddening the binary MCMC applies to the WD
        predicted = np.array([wd[bnd][0] for bnd in bands_table['wd_band']])
        return predicted[np.isin(list(bands_table['lambda_eff']),x)]

    t2_guess = 10000
    m2_guess = 0.6

    cut = ~np.isnan(flux) & ~np.isnan(flux_err)
    if no_uv:
        cut = cut & (wl>3000)

    y = flux[cut]
    y_err = flux_err[cut]
    x = wl[cut]

    res = curve_fit(model_flux,x,y,p0=[t2_guess,m2_guess],bounds=[(5000,0.4),(50000,1.4)],sigma=y_err,absolute_sigma=True)
    t2_fit,m2_fit = res[0]
    t2_err,m2_err = np.sqrt(np.diag(res[1]))
    # model_flux takes (x, t, m2); the old call passed core/atm too and raised TypeError
    chi2 = np.sum((y - model_flux(x,t2_fit,m2_fit))**2 / y_err**2)
    dof = len(y) - 2
    redchi2 = chi2 / dof
    return t2_fit,t2_err,m2_fit,m2_err,redchi2

In [None]:
# For Yarin: fit a single WD SED to one Gaia DR3 source.
# NOTE(review): ag_gspphot is queried but Av is still fixed to 0 below -- TODO decide how to derive Av.
gaia_source_id = 1894996161487458048
query = f'''SELECT source_id, ra, dec, parallax, parallax_error, mh_gspphot, ag_gspphot FROM gaiadr3.gaia_source WHERE source_id = {gaia_source_id}'''
job = Gaia.launch_job(query)
res = job.get_results()

# Float literals for Av and m1 so later assignments into these columns are not truncated to int.
tbl = Table({'source_id':[gaia_source_id],'ra':[res['ra'][0]],'dec':[res['dec'][0]],
             'parallax':[res['parallax'][0]],'parallax_error':[res['parallax_error'][0]],'[Fe/H]':[res['mh_gspphot'][0]],'Av':[0.0],'m1':[1.0]})
tbl.add_row(tbl[0])  # NOTE(review): duplicates the only row; purpose unclear -- confirm
tbl['teff1'] = np.full(len(tbl),np.nan)
tbl['teff1_err'] = np.full(len(tbl),np.nan)
tbl['r1'] = np.full(len(tbl),np.nan)
tbl['r1_err'] = np.full(len(tbl),np.nan)
tbl['idx'] = np.arange(len(tbl))
wd_fit = fit_src_wd(tbl,0,no_uv=True)  # (t2_fit, t2_err, m2_fit, m2_err, redchi2)
wd_fit
# teff1, teff1_err, r1, r1_err, redchi2, n_ir, n_opt = fit_src_kurucz(tbl,0,True)
# tbl[0]['teff1'] = teff1
# tbl[0]['teff1_err'] = teff1_err
# tbl[0]['r1'] = r1
# tbl[0]['r1_err'] = r1_err
# plot_kurucz_with_residuals(tbl,0,True,True,True)


# Table for the paper

In [None]:
# Refit Teff1/R1 of every primary, store the results in `sources` and save the diagnostic plots.
sources = Table.read('../table_C.fits', format='fits')
sources['teff1'] = np.full(len(sources),np.nan)
sources['teff1_err'] = np.full(len(sources),np.nan)
sources['r1'] = np.full(len(sources),np.nan)
sources['r1_err'] = np.full(len(sources),np.nan)
no_IR = True
for i,idx in enumerate(sources['idx']):
    # fit_src_kurucz returns 7 values. N_ir/N_opt go to `_` so the global
    # n_opt / n_ir lists used by the table cells are not overwritten.
    teff1, teff1_err, r1, r1_err, redchi2, _, _ = fit_src_kurucz(sources,idx,no_IR=no_IR)
    sources[i]['teff1'] = teff1
    sources[i]['teff1_err'] = teff1_err
    sources[i]['r1'] = r1
    sources[i]['r1_err'] = r1_err
    plot_kurucz_with_residuals(sources,idx,plot = False, save= True,no_IR=no_IR)

In [14]:
# Fit every source twice with fit_src_kurucz: optical-only (no_IR=True) and optical+IR (no_IR=False).
# The stored `chi2` is the reduced chi2 that fit_src_kurucz evaluates on all usable bands
# above 3000 AA, whichever bands were used in the fit itself.
sources = Table.read('../table_C.fits', format='fits')
chi2_opt = []  # reduced chi2 of the optical-only fit
n_opt = []     # number of non-NaN fluxes between 3000 and 10000 AA
chi2_ir = []   # reduced chi2 of the optical+IR fit
n_ir = []      # number of non-NaN fluxes above 10000 AA
for idx in tqdm(sources['idx']):
    teff1, teff1_err, r1, r1_err, chi2, nir, nopt = fit_src_kurucz(sources,idx,no_IR=True)
    chi2_opt.append(chi2)
    n_opt.append(nopt)
    n_ir.append(nir)
    # point counts do not depend on no_IR, so only chi2 is kept from the second fit
    teff1, teff1_err, r1, r1_err, chi2, nir, nopt = fit_src_kurucz(sources,idx,no_IR=False)
    chi2_ir.append(chi2)

100%|██████████| 40/40 [02:19<00:00,  3.50s/it]


In [19]:
import numpy as np

def create_latex_table(data, chi2_opt, chi2_ir, n_opt, n_ir, sources, no_ir_idx=(101,140,160,192,194,282)):
    """Build an AASTeX deluxetable* summarising the SED fits.

    Parameters
    ----------
    data : iterable of rows with an 'idx' field; defines row order. Entry i of
        chi2_opt / chi2_ir / n_opt / n_ir must correspond to data[i].
    chi2_opt, chi2_ir : reduced chi2 of the optical-only and optical+IR fits.
    n_opt, n_ir : number of optical / IR data points per source.
    sources : Table with 'idx', 'teff1', 'teff1_err', 'r1', 'r1_err' (plus the
        columns calculate_uv_excess needs). Rows are looked up by idx.
    no_ir_idx : source indices whose optical+IR chi2 is reported as nodata.

    Returns
    -------
    str : LaTeX source for the table.
    """
    def _pm_str(value, err, as_int):
        # "$value \pm err$", both rounded at the first significant digit of err.
        # Returns nodata when err is NaN/zero: log10 of it would otherwise crash.
        if not (np.isfinite(value) and np.isfinite(err) and err > 0):
            return '\\nodata'
        digits = -int(np.floor(np.log10(err)))
        value_r = np.round(value, digits)
        err_r = np.round(err, digits)
        if as_int:
            return f"${int(value_r)} \\pm {int(err_r)}$"
        return f"${value_r} \\pm {err_r}$"

    # Start the LaTeX table (10 columns, so \tablecolumns must be 10)
    latex_code = r"""
\begin{deluxetable*}{cccccccccc} 
\tablecaption{SED Fitting Results} 
\tabletypesize{\small}
\tablewidth{0pt}
\tablecolumns{10}
\tablehead{
\colhead{Idx} & \colhead{N$_\text{opt}$} & \colhead{N$_\text{IR}$} & \colhead{$\chi^2_\text{opt}$} & \colhead{$\chi^2_\text{ir}$} & \colhead{$T_\text{eff,1}$} &
 \colhead{$R_1$} & \colhead{$T_\text{eff,2}$} & \colhead{NUV Excess} & \colhead{FUV Excess} \\ [-0.5em]
\colhead{} & \colhead{\scriptsize Data points in optical} & \colhead{\scriptsize Data points in IR} & 
\colhead{\scriptsize Optical only fit} & 
\colhead{\scriptsize Optical+IR fit} & 
\colhead{\scriptsize (K)} & 
\colhead{\scriptsize ($R_\odot$)} & 
\colhead{\scriptsize (K)} &
\colhead{\scriptsize (erg cm$^{-2}$ s$^{-1}$ $\AA^{-1}$)} & 
\colhead{\scriptsize (erg cm$^{-2}$ s$^{-1}$ $\AA^{-1}$)}}
\startdata
"""
    for i, row in enumerate(data):
        idx = row['idx']
        # look the source up by idx so `data` need not share the row order of `sources`
        src = sources[np.where(sources['idx'] == idx)[0][0]]

        chi2_opt_val = round(chi2_opt[i], 2)
        chi2_ir_val = '\\nodata' if idx in no_ir_idx else round(chi2_ir[i], 2)

        excess = calculate_uv_excess(sources, idx)
        nuv_excess = excess['nuv_excess']
        fuv_excess = excess['fuv_excess']

        teff1_str = _pm_str(src['teff1'], src['teff1_err'], as_int=True)
        r1_str = _pm_str(src['r1'], src['r1_err'], as_int=False)

        teff2_str = '\\nodata'
        chain_path = f'../data/chains/chains_{idx}.csv'
        if os.path.exists(chain_path):
            chain = Table.read(chain_path)
            teff2_str = _pm_str(np.median(chain['teff2']), np.std(chain['teff2']), as_int=True)

        # only positive, unmasked excesses are reported
        nuv_excess_val = f"{nuv_excess:.2e}" if not (np.ma.is_masked(nuv_excess) or nuv_excess <0) else "\\nodata"
        fuv_excess_val = f"{fuv_excess:.2e}" if not (np.ma.is_masked(fuv_excess) or fuv_excess <0) else "\\nodata"
        latex_code += f"{idx} & {n_opt[i]} & {n_ir[i]} & {chi2_opt_val} & {chi2_ir_val} & {teff1_str} & {r1_str} & {teff2_str} & {nuv_excess_val} & {fuv_excess_val} \\\\ \n"

    # Close the LaTeX table
    latex_code += r"""
\enddata
\end{deluxetable*}
"""

    return latex_code

# Build the paper table from the fits above; `sources` is both the row iterator and the parameter table.
latex_table = create_latex_table(sources, chi2_opt, chi2_ir, n_opt, n_ir, sources)
print(latex_table)



\begin{deluxetable*}{cccccccccc} 
\tablecaption{SED Fitting Results} 
\tabletypesize{\small}
\tablewidth{0pt}
\tablecolumns{9}
\tablehead{
\colhead{Idx} & \colhead{N$_\text{opt}$} & \colhead{N$_\text{IR}$} & \colhead{$\chi^2_\text{opt}$} & \colhead{$\chi^2_\text{ir}$} & \colhead{$T_\text{eff,1}$} &
 \colhead{$R_1$} & \colhead{$T_\text{eff,2}$} & \colhead{NUV Excess} & \colhead{FUV Excess} \\ [-0.5em]
\colhead{} & \colhead{\scriptsize Data points in optical} & \colhead{\scriptsize Data points in IR} & 
\colhead{\scriptsize Optical only fit} & 
\colhead{\scriptsize Optical+IR fit} & 
\colhead{\scriptsize (K)} & 
\colhead{\scriptsize ($R_\odot$)} & 
\colhead{\scriptsize (K)} &
\colhead{\scriptsize (erg cm$^{-2}$ s$^{-1}$ $\AA^{-1}$)} & 
\colhead{\scriptsize (erg cm$^{-2}$ s$^{-1}$ $\AA^{-1}$)}}
\startdata
6 & 8 & 6 & 0.79 & 0.68 & $8200 \pm 200$ & $1.99 \pm 0.06$ & \nodata & \nodata & \nodata \\ 
16 & 8 & 6 & 1.03 & 0.85 & $8100 \pm 200$ & $1.86 \pm 0.05$ & \nodata & \nodata & \nodata \\

In [17]:
from astropy.table import Table, Column
import numpy as np

def create_astropy_table(data, chi2_opt, chi2_ir, n_opt, n_ir, sources, no_ir_idx=(101,140,160,192,194,282)):
    """Machine-readable counterpart of create_latex_table.

    Parameters
    ----------
    data : iterable of rows with an 'idx' field; defines row order. Entry i of
        chi2_opt / chi2_ir / n_opt / n_ir must correspond to data[i].
    chi2_opt, chi2_ir : reduced chi2 of the optical-only and optical+IR fits.
    n_opt, n_ir : number of optical / IR data points per source.
    sources : Table with 'idx', 'teff1', 'teff1_err', 'r1', 'r1_err' (plus the
        columns calculate_uv_excess needs). Rows are looked up by idx.
    no_ir_idx : source indices whose optical+IR chi2 is set to NaN, matching the
        nodata entries of the LaTeX table.

    Returns
    -------
    astropy Table with one row per entry of ``data``. Missing or negative UV
    excesses and missing Teff2 chains become NaN.
    """
    idx_list = [row['idx'] for row in data]
    src_rows = [sources[np.where(sources['idx'] == idx)[0][0]] for idx in idx_list]
    # calculate_uv_excess is expensive; evaluate it once per source (it used to run 6x).
    excesses = [calculate_uv_excess(sources, idx) for idx in idx_list]
    flux_unit = 'erg cm^-2 s^-1 Angstrom^-1'  # a bare 'A' would be parsed as ampere

    def _excess_or_nan(val):
        # Keep only real (unmasked, non-negative) excesses.
        if np.ma.is_masked(val) or val < 0:
            return np.nan
        return val

    def _teff2_median(idx):
        # Median Teff2 of the binary-fit chain for this source, NaN if no chain file exists.
        path = f'../data/chains/chains_{idx}.csv'
        if not os.path.exists(path):
            return np.nan
        return np.median(Table.read(path)['teff2'])

    columns = [
        Column(name='Idx', data=idx_list, dtype=int),
        Column(name='N_opt', data=n_opt, dtype=int, unit='Data points in optical'),
        Column(name='N_ir', data=n_ir, dtype=int, unit='Data points in IR'),
        Column(name='chi2_opt', data=[round(val, 2) for val in chi2_opt], dtype=float, unit='Optical only fit'),
        # the no-IR exclusion belongs to the optical+IR chi2 (it was applied to chi2_opt before)
        Column(name='chi2_ir', data=[np.nan if idx in no_ir_idx else round(val, 2) for idx, val in zip(idx_list, chi2_ir)], dtype=float, unit='Optical+IR fit'),
        Column(name='Teff1', data=[round(src['teff1'], 2) for src in src_rows], dtype=float, unit='K'),
        Column(name='Teff1_err', data=[round(src['teff1_err'], 2) for src in src_rows], dtype=float, unit='K'),
        Column(name='R1', data=[round(src['r1'], 2) for src in src_rows], dtype=float, unit='R_sun'),
        Column(name='R1_err', data=[round(src['r1_err'], 2) for src in src_rows], dtype=float, unit='R_sun'),
        Column(name='Teff2', data=[_teff2_median(idx) for idx in idx_list], dtype=float, unit='K'),
        Column(name='NUV_Excess', data=[_excess_or_nan(ex['nuv_excess']) for ex in excesses], dtype=float, unit=flux_unit),
        Column(name='FUV_Excess', data=[_excess_or_nan(ex['fuv_excess']) for ex in excesses], dtype=float, unit=flux_unit),
    ]
    return Table(columns)

# Build the machine-readable results table and write it next to the notebook (overwriting any old copy).
astropy_table = create_astropy_table(sources, chi2_opt, chi2_ir, n_opt, n_ir, sources)
astropy_table.write('sed_results.fits', format='fits', overwrite=True)

# Binary SED fit

In [25]:
import emcee
from multiprocessing import Pool

# MS + WD binary fit: observed photometry and prior centres/widths for one source.
sources = Table.read('../table_C.fits')
idx = 283
core = 'He'  # WD core composition passed to get_cooling_temp / get_wd_sed
atm = 'H'    # WD atmosphere passed to get_cooling_temp / get_wd_sed

i = np.where(sources['idx'] == idx)[0][0]

obs_tbl = get_photometry_single_source(Table(sources[i]))
bnds = list(bands_table['band'])
wl = np.array(list(bands_table['lambda_eff']))
flux = np.array([obs_tbl[0][bnd] for bnd in bnds])
flux_err = np.array([obs_tbl[0][bnd + '_err'] for bnd in bnds])

# Gaussian-prior centres and widths from the source table
parallax_guess = sources[i]['parallax']
parallax_err = sources[i]['parallax_error']
meta_guess = sources[i]['[Fe/H]']
meta_err = sources[i]['e_[Fe/H]']
av_guess = sources[i]['Av']
# av_err = sources[i]['e_Av']
av_err = 0.1  # fixed A_V prior width (the catalogue e_Av is not used here)
age = sources[i]['age']

# Starting values / widths for the walkers (see init_guess)
teff1_guess = sources[i]['teff1']
teff1_err = sources[i]['teff1_err']
r1_guess = sources[i]['r1']
r1_err = sources[i]['r1_err']
m1_guess = sources[i]['m1']
m1_err = sources[i]['m1_err']
m2_guess = sources[i]['m2']
m2_err = sources[i]['m2_err']
teff2_min = get_cooling_temp(age,m2_guess,core,atm)
teff2_arr = [7000,7300,12000]  # (low, guess, high) Teff2 [K]; walkers start uniform in [low, high]
teff2_lo, teff2_guess, teff2_hi = teff2_arr
teff2_min = np.nanmin([teff2_min,teff2_lo])  # Teff2 prior lower bound: smaller of the cooling temperature and teff2_lo (NaN ignored)


# ## mock data
# mock_ms = get_ms_sed(teff1_guess,m1_guess,r1_guess,meta_guess,parallax_guess)
# mock_ms = redden_model_table(mock_ms,teff1_guess,av_guess)
# mock_wd = get_wd_sed(teff2_guess,m2_guess,parallax_guess,core,atm)
# mock_wd = redden_model_table(mock_wd,teff2_guess,av_guess)
# flux = np.array([mock_wd[bnd][0] + mock_ms[bnd][0] for bnd in bands_table['wd_band']])
# flux_err = 0.1 * flux
# flux = np.random.normal(flux,flux_err*1.05)
# ## mock data

def init_guess(nwalkers):
    """Draw starting positions for the MS + WD MCMC walkers.

    Returns an array of shape (nwalkers, 8) with columns
    [teff1, teff2, r1, m1, m2, av, parallax, meta]:
    - Teff1 and R1 are drawn from normals around the table values.
    - Teff2 is uniform in [teff2_lo, teff2_hi].
    - The rest come from an uncorrelated multivariate normal around the catalogue values.

    Uses the global NumPy random state; no seed is set in this cell.
    """
    t1 = np.random.normal(teff1_guess,teff1_err,nwalkers)
    t2 = np.random.uniform(teff2_lo, teff2_hi, nwalkers)
    r1 = np.random.normal(r1_guess,r1_err,nwalkers)
    covm = np.diag([m1_err**2,m2_err**2,av_err**2,parallax_err**2,meta_err**2])
    m1,m2,av,parallax,meta = np.random.multivariate_normal([m1_guess,m2_guess,av_guess,parallax_guess,meta_guess],covm,nwalkers).T
    return np.array([t1,t2,r1,m1,m2,av,parallax,meta]).T

def log_prior_parallax(parallax): ## Gaussian prior on parallax (unnormalised), truncated at 3 sigma
    snr2 = (parallax - parallax_guess)**2 / parallax_err**2
    if snr2 > 3**2: # truncate at 3 sigma
        return -np.inf
    return -0.5 * snr2
def log_prior_av(av): ## Gaussian prior on Av, restricted to Av >= 0 (no sigma truncation)
    if av < 0:
        return -np.inf
    snr2 = (av - av_guess)**2 / av_err**2
    # if snr2 > 3**2: # truncate at 3 sigma
    #     return -np.inf
    return -0.5 * snr2
def log_prior_meta(meta): ## Gaussian prior on [Fe/H], truncated at 3 sigma and to -0.9 < [Fe/H] < 0.4
    if meta >= 0.4 or meta <= -0.9:
        return -np.inf
    snr2 = (meta - meta_guess)**2 / meta_err**2
    if snr2 > 3**2: # truncate at 3 sigma
        return -np.inf
    return -0.5 * snr2
def log_prior_m1(m1): ## Gaussian prior on m1, truncated at 3 sigma
    snr2 = (m1 - m1_guess)**2 / m1_err**2
    if snr2 > 3**2: # truncate at 3 sigma
        return -np.inf
    return -0.5 * snr2
def log_prior_m2(m2): ## Gaussian prior on m2, truncated at 3 sigma
    snr2 = (m2 - m2_guess)**2 / m2_err**2
    if snr2 > 3**2: # truncate at 3 sigma
        return -np.inf
    return -0.5 * snr2
def log_prior_r1(r1): ## Flat prior on r1 over (0.1, 5)
    if 0.1 < r1 < 5:
        return 0
    return -np.inf
def log_prior_teff1(teff1): ## Flat prior on teff1 over (3500, 15000) K
    if 3500 < teff1 < 15000:
        return 0
    return -np.inf
def log_prior_teff2(teff2): ## Flat prior on teff2 over (teff2_min, 80000) K
    if teff2_min < teff2 < 80000:
        return 0
    return -np.inf
def log_prior(params):
    """Sum of the independent log-priors for (teff1, teff2, r1, m1, m2, av, parallax, meta)."""
    teff1,teff2,r1,m1,m2,av,parallax,meta = params
    lp_teff1 = log_prior_teff1(teff1)
    lp_teff2 = log_prior_teff2(teff2)
    lp_r1 = log_prior_r1(r1)
    lp_av = log_prior_av(av)
    # short-circuit before evaluating the remaining terms if any support bound is already violated
    if np.isinf(lp_teff1 + lp_teff2 + lp_r1 + lp_av):
        return -np.inf
    lp_m1 = log_prior_m1(m1)
    lp_m2 = log_prior_m2(m2)
    lp_parallax = log_prior_parallax(parallax)
    lp_meta = log_prior_meta(meta)
    return lp_teff1 + lp_teff2 + lp_r1 + lp_m1 + lp_m2 + lp_av + lp_parallax + lp_meta

def log_likelihood(params,wl,flux,flux_err):
    """Gaussian log-likelihood of the photometry under the reddened MS + WD model.

    params : sequence (teff1, teff2, r1, m1, m2, av, parallax, meta).
    wl : unused; kept so the emcee ``args`` tuple stays (wl, flux, flux_err).
    flux, flux_err : observed fluxes and 1-sigma errors, in bands_table row order.

    Bands where the observed flux, its error, or the model flux is NaN are skipped.
    A NaN error previously propagated into the sum and made the walker's
    log-probability NaN. The ln(sigma^2) normalisation term is included.
    """
    teff1,teff2,r1,m1,m2,av,parallax,meta = params
    ms = get_ms_sed(teff1,m1,r1,meta,parallax)
    ms = redden_model_table(ms,teff1,av)
    wd = get_wd_sed(teff2,m2,parallax,core,atm)
    wd = redden_model_table(wd,teff2,av)
    model_flux = np.array([ms[bnd][0] + wd[bnd][0] for bnd in bands_table['wd_band']])
    cut = ~np.isnan(flux) & ~np.isnan(flux_err) & ~np.isnan(model_flux)
    sigma2 = flux_err[cut]**2
    return -0.5 * np.sum((flux[cut] - model_flux[cut])**2 / sigma2 + np.log(sigma2))

def log_probability(params,wl,flux,flux_err):
    """emcee log-posterior: log prior plus log likelihood, or -inf outside the prior support."""
    prior = log_prior(params)
    if np.isfinite(prior):
        return prior + log_likelihood(params,wl,flux,flux_err)
    return -np.inf

# 30 walkers from init_guess, 10000 steps, run in a multiprocessing pool.
# Only the stretch move is used: the DE-snooker move has weight 0 and is never selected.
pos = init_guess(30)
nwalkers, ndim = pos.shape
nsteps = 10000

with Pool() as pool:
    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_probability, pool=pool, args=(wl,flux,flux_err),moves=[
        (emcee.moves.StretchMove(a=2), 1.0),
        (emcee.moves.DESnookerMove(), 0.0),
    ])
    sampler.run_mcmc(pos, nsteps, progress=True)

100%|██████████| 10000/10000 [54:17<00:00,  3.07it/s] 


In [67]:
# MS + BB fit
import emcee
from multiprocessing import Pool

# MS + blackbody fit: observed photometry and prior centres/widths for one source.
sources = Table.read('../table_C.fits')
idx = 180
i = np.where(sources['idx'] == idx)[0][0]

obs_tbl = get_photometry_single_source(Table(sources[i]))
bnds = list(bands_table['band'])
wl = np.array(list(bands_table['lambda_eff']))
flux = np.array([obs_tbl[0][bnd] for bnd in bnds])
flux_err = np.array([obs_tbl[0][bnd + '_err'] for bnd in bnds])

# Gaussian-prior centres and widths from the source table
parallax_guess = sources[i]['parallax']
parallax_err = sources[i]['parallax_error']
meta_guess = sources[i]['[Fe/H]']
meta_err = sources[i]['e_[Fe/H]']
av_guess = sources[i]['Av']
av_err = sources[i]['e_Av']
age = sources[i]['age']

# Starting values / widths for the walkers (see init_guess)
teff1_guess = sources[i]['teff1']
teff1_err = sources[i]['teff1_err']
r1_guess = sources[i]['r1']
r1_err = sources[i]['r1_err']
m1_guess = sources[i]['m1']
m1_err = sources[i]['m1_err']

r2_guess = 0.10  # blackbody radius, held fixed in log_likelihood (not sampled)

teff2_arr = [12000,12500,13000]  # (low, guess, high) blackbody Teff [K]; walkers start uniform in [low, high]
teff2_lo, teff2_guess, teff2_hi = teff2_arr

def init_guess(nwalkers):
    """Draw starting positions for the MS + blackbody MCMC walkers.

    Returns an array of shape (nwalkers, 7) with columns
    [teff1, teff2, r1, m1, av, parallax, meta]:
    - Teff1 and R1 are drawn from normals around the table values.
    - Teff2 is uniform in [teff2_lo, teff2_hi].
    - The rest come from an uncorrelated multivariate normal around the catalogue values.

    Uses the global NumPy random state; no seed is set in this cell.
    """
    t1 = np.random.normal(teff1_guess,teff1_err,nwalkers)
    t2 = np.random.uniform(teff2_lo, teff2_hi, nwalkers)
    r1 = np.random.normal(r1_guess,r1_err,nwalkers)
    covm = np.diag([m1_err**2,av_err**2,parallax_err**2,meta_err**2])
    m1,av,parallax,meta = np.random.multivariate_normal([m1_guess,av_guess,parallax_guess,meta_guess],covm,nwalkers).T
    return np.array([t1,t2,r1,m1,av,parallax,meta]).T

def log_prior_parallax(parallax): ## Gaussian prior on parallax (unnormalised), truncated at 3 sigma
    snr2 = (parallax - parallax_guess)**2 / parallax_err**2
    if snr2 > 3**2: # truncate at 3 sigma
        return -np.inf
    return -0.5 * snr2
def log_prior_av(av): ## Gaussian prior on Av, restricted to Av >= 0 (no sigma truncation)
    if av < 0:
        return -np.inf
    snr2 = (av - av_guess)**2 / av_err**2
    # if snr2 > 3**2: # truncate at 3 sigma
    #     return -np.inf
    return -0.5 * snr2
def log_prior_meta(meta): ## Gaussian prior on [Fe/H], truncated at 3 sigma and to -0.9 < [Fe/H] < 0.4
    if meta >= 0.4 or meta <= -0.9:
        return -np.inf
    snr2 = (meta - meta_guess)**2 / meta_err**2
    if snr2 > 3**2: # truncate at 3 sigma
        return -np.inf
    return -0.5 * snr2
def log_prior_m1(m1): ## Gaussian prior on m1, truncated at 3 sigma
    snr2 = (m1 - m1_guess)**2 / m1_err**2
    if snr2 > 3**2: # truncate at 3 sigma
        return -np.inf
    return -0.5 * snr2
def log_prior_r1(r1): ## Flat prior on r1 over (0.1, 5)
    if 0.1 < r1 < 5:
        return 0
    return -np.inf
def log_prior_teff1(teff1): ## Flat prior on teff1 over (3500, 15000) K
    if 3500 < teff1 < 15000:
        return 0
    return -np.inf
def log_prior_teff2(teff2): ## Flat prior on the blackbody teff2 over (0, 80000) K
    if 0< teff2 < 80000:
        return 0
    return -np.inf
def log_prior(params):
    """Sum of the independent log-priors for (teff1, teff2, r1, m1, av, parallax, meta)."""
    teff1,teff2,r1,m1,av,parallax,meta = params
    lp_teff1 = log_prior_teff1(teff1)
    lp_teff2 = log_prior_teff2(teff2)
    lp_r1 = log_prior_r1(r1)
    lp_av = log_prior_av(av)
    # short-circuit before evaluating the remaining terms if any support bound is already violated
    if np.isinf(lp_teff1 + lp_teff2 + lp_r1 + lp_av):
        return -np.inf
    lp_m1 = log_prior_m1(m1)
    lp_parallax = log_prior_parallax(parallax)
    lp_meta = log_prior_meta(meta)
    return lp_teff1 + lp_teff2 + lp_r1 + lp_m1 + lp_av + lp_parallax + lp_meta

def log_likelihood(params,wl,flux,flux_err):
    """Gaussian log-likelihood of the photometry under the reddened MS + blackbody model.

    params : sequence (teff1, teff2, r1, m1, av, parallax, meta). The blackbody
        radius is the fixed global r2_guess.
    wl : unused; kept so the emcee ``args`` tuple stays (wl, flux, flux_err).
    flux, flux_err : observed fluxes and 1-sigma errors, in bands_table row order.

    Bands where the observed flux, its error, or the model flux is NaN are skipped.
    A NaN error previously propagated into the sum and made the walker's
    log-probability NaN. The ln(sigma^2) normalisation term is included.
    """
    teff1,teff2,r1,m1,av,parallax,meta = params
    ms = get_ms_sed(teff1,m1,r1,meta,parallax)
    ms = redden_model_table(ms,teff1,av)
    bb = blackbody_mod_table(teff=teff2, radius=r2_guess, parallax=parallax, parallax_err=parallax_err, Av=av)
    model_flux = np.array([ms[wdbnd][0] + bb[bnd][0] for wdbnd,bnd in zip(bands_table['wd_band'],bands_table['band'])])
    cut = ~np.isnan(flux) & ~np.isnan(flux_err) & ~np.isnan(model_flux)
    sigma2 = flux_err[cut]**2
    return -0.5 * np.sum((flux[cut] - model_flux[cut])**2 / sigma2 + np.log(sigma2))

def log_probability(params,wl,flux,flux_err):
    """emcee log-posterior: log prior plus log likelihood, or -inf outside the prior support."""
    prior = log_prior(params)
    if np.isfinite(prior):
        return prior + log_likelihood(params,wl,flux,flux_err)
    return -np.inf

# 30 walkers from init_guess, 500 steps, run in a multiprocessing pool.
# Only the stretch move is used: the DE-snooker move has weight 0 and is never selected.
pos = init_guess(30)
nwalkers, ndim = pos.shape
nsteps = 500
with Pool() as pool:
    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_probability, pool=pool, args=(wl,flux,flux_err),moves=[
        (emcee.moves.StretchMove(a=2), 1.0),
        (emcee.moves.DESnookerMove(), 0.0),
    ])
    sampler.run_mcmc(pos, nsteps, progress=True)

100%|██████████| 500/500 [04:23<00:00,  1.90it/s]


In [26]:
# Walker traces for the first three sampled parameters (teff1, teff2, r1).
fig, axes = plt.subplots(3, figsize=(10, 7), sharex=True)
samples = sampler.get_chain()  # (nsteps, nwalkers, ndim)
# 8 names matching the MS + WD parameter order.
# NOTE(review): this does not match the 7-parameter MS + BB run.
labels = ["teff1", "teff2", "r1", "m1",'m2', "Av", "parallax", "[Fe/H]"]
ndim = samples.shape[2]
for i in range(3):
    ax = axes[i]
    ax.plot(samples[:, :, i], "k", alpha=0.3)
    ax.set_xlim(0, len(samples))
    ax.set_ylabel(labels[i])
    
axes[-1].set_xlabel("step number")

Text(0.5, 0, 'step number')

In [None]:
# Integrated autocorrelation time per parameter. With the default tol=50, emcee raises
# AutocorrError if the chain is shorter than 50 autocorrelation times; tol=0 returns the estimate anyway.
tau = sampler.get_autocorr_time()
print(tau)

In [27]:
from corner import corner

# Burn-in (steps dropped from the start of every walker) and thinning interval.
# The sampling cell runs nsteps = 500, so a fixed discard of 500 would leave an
# empty chain. Clamp the burn-in to half of the steps actually run, and the
# thinning so that each walker still contributes at least two samples.
burn_in = 500
thin = 75
if burn_in >= sampler.iteration:
    burn_in = sampler.iteration // 2
thin = max(1, min(thin, (sampler.iteration - burn_in) // 2))
flat_samples = sampler.get_chain(discard=burn_in, thin=thin, flat=True)
assert len(flat_samples) > 0, (
    f'no samples left after discard={burn_in}, thin={thin} '
    f'from a {sampler.iteration}-step chain'
)

_ = corner(flat_samples.reshape(-1,ndim),labels=labels,show_titles=True,truths=[teff1_guess,teff2_guess,r1_guess,m1_guess,m2_guess,av_guess,parallax_guess,meta_guess],labelpad=0.1)

In [37]:
# Marginal samples from columns 0-2 of the flattened chain (teff1, teff2, r1).
# NOTE(review): this rebinds the global name r1.
t1 = flat_samples[:,0]
t2 = flat_samples[:,1]
r1 = flat_samples[:,2]
# Only a cell's last expression is echoed, so a bare np.corrcoef(...) here would
# be silently discarded; print the teff1-r1 correlation matrix explicitly.
print(np.corrcoef(t1, r1))

# Persist the thinned, flattened chain for this source. Create the output
# directory first so the write does not fail when it is missing.
# NOTE(review): Table(..., names=labels) requires len(labels) to equal the number
# of chain columns; `labels` lists 8 names while log_likelihood unpacks 7.
chains_path = f'../data/class1_chains/chains_{idx}.csv'
os.makedirs(os.path.dirname(chains_path), exist_ok=True)
Table(flat_samples, names=labels).write(chains_path, overwrite=True)

In [38]:
# Posterior medians, one per chain column, in the order
# teff1, teff2, r1, m1, m2, Av, parallax, [Fe/H].
# NOTE(review): log_likelihood samples 7 parameters (no m2); check this column
# order against the chain that was actually run.
posterior_medians = np.median(flat_samples, axis=0)
teff1_fit = posterior_medians[0]
teff2_fit = posterior_medians[1]
r1_fit = posterior_medians[2]
m1_fit = posterior_medians[3]
m2_fit = posterior_medians[4]
av_fit = posterior_medians[5]
parallax_fit = posterior_medians[6]
meta_fit = posterior_medians[7]

# One-row copy of this source's catalogue entry, with the fitted values
# written in before it is passed to the plotting call below.
j = np.where(sources['idx'] == idx)[0][0]
src_fit = Table(sources[j])
fit_row = src_fit[0]
fit_row['m1'] = m1_fit
fit_row['Av'] = av_fit
fit_row['parallax'] = parallax_fit
fit_row['[Fe/H]'] = meta_fit

# Alternative diagnostic plots:
# plot_mswd_vs_obs(sources,idx,(teff1_guess,r1_guess),(teff2_guess,core,atm),plot=True,save=False)
# plot_mswd_vs_obs(src_fit,idx,(teff1_fit,r1_fit),(teff2_fit,core,atm),plot=True,save=False)
# plot_msbb_with_residuals(sources,idx,(teff1_guess,r1_guess),(teff2_fit,r2_guess),plot=True,save=True)
plot_mswd_with_residuals(src_fit,idx,(teff1_fit,r1_fit),(teff2_fit,core,atm),plot=True,save=True)
# plot_msbb_with_residuals(sources,idx,(teff1_guess,r1_guess),(teff2_fit,r2_guess),plot=True,save=True)

# Misc: Swift SNR estimates and LaTeX tables

In [None]:
# zp = 7.5e-16
# print('idx cts/s texp')
# for i in range(len(sources)):
#     av = sources[i]['Av']
#     anuv = extinction.get_Galex_extinction(av,0,sources[i]['teff2_min'])[0]
#     m2 = sources[i]['m2']
#     age = sources[i]['age']
#     nuv = m_agecool_to_mag(m2,age*1e-3)
#     parallax = sources[i]['parallax']
#     nuv_apparent = nuv + 5 * np.log10(1000/parallax) - 5 + anuv
#     c = obs_nuv_flux(nuv_apparent,av)/zp
#     texp = 20**2 / c
#     if texp < 1000 and np.ma.is_masked(sources[i]['nuv_mag']):
#         print(sources[i]['idx'],c.round(2),int(texp))
    

In [None]:
# sources = sources[sources['idx'] == 99]

# sources['phot_g_mean_mag'] = sources['phot_g_mean_mag'].round(1)
# sources['ra'] = sources['ra'].round(5)
# sources['dec'] = sources['dec'].round(5)
# sources['distance'] = (1000/sources['parallax']).round(0)
# sources['distance_err'] = (1000/sources['parallax']**2 * sources['parallax_error']).round(0)
# sources['D'] = [f'{d}±{d_err}' for d,d_err in zip(sources['distance'],sources['distance_err'])]
# sources['nuv_mag'] = sources['nuv_mag'].round(1)
# sources['m1'] = sources['m1'].round(2)
# sources['m1_err'] = sources['m1_err'].round(2)
# sources['m2'] = sources['m2'].round(2)
# sources['m2_err'] = sources['m2_err'].round(2)
# sources['r1'] = sources['r1'].round(2)
# sources['r1_err'] = sources['r1_err'].round(2)
# sources['teff1'] = np.array(sources['teff1'].round(-1),dtype = int)
# sources['teff1_err'] = np.array(sources['teff1_err'].round(-1),dtype = int)
# sources['teff2'] = np.array(sources['teff2'].round(-1),dtype = int)
# sources['age'] = (10**sources['log_age_50'] * 1e-6).round(0)
# sources['M1'] = [f'{m:.2f}±{m_err:.2f}' for m,m_err in zip(sources['m1'],sources['m1_err'])]
# sources['R1'] = [f'{r:.2f}±{r_err:.2f}' for r,r_err in zip(sources['r1'],sources['r1_err'])]
# sources['M2'] = [f'{m:.2f}±{m_err:.2f}' for m,m_err in zip(sources['m2'],sources['m2_err'])]
# sources['T1'] = [f'{t}±{t_err}' for t,t_err in zip(sources['teff1'],sources['teff1_err'])]
# sources['[Fe/H]'] = sources['[Fe/H]'].round(2)
# sources['Av'] = sources['Av'].round(2)
# sources['age'] = sources['age'].round(0)   
# print(sources['T1','R1','M2','teff2','[Fe/H]','Av','age'].to_pandas().to_latex(index=False))

# Plotting

In [None]:
sources = Table.read('../table_C.fits',format='fits')

# Select the target by row position. To select by catalogue id instead, use
# idx = 1; i = np.where(sources['idx'] == idx)[0][0]
i = 5
row = sources[i]
idx = row['idx']

# Component masses. A secondary mass above 1.25 is pulled down by one sigma
# (NOTE(review): presumably the upper edge of the WD model grid; confirm).
m1, m1_err = row['m1'], row['m1_err']
m2, m2_err = row['m2'], row['m2_err']
if m2 > 1.25:
    m2 = m2 - m2_err

# NOTE(review): this cell rebinds globals such as parallax_err that
# log_likelihood also reads; re-run the fitting setup before sampling again.
parallax, parallax_err = row['parallax'], row['parallax_error']

teff1 = row['teff1']
# teff2 = sources[i]['teff2']

meta = row['[Fe/H]']

# Alternative model plots:
# wl, fl = plot_binary_wd(teff1,m1,m1_err,teff2,m2,m2_err,parallax,parallax_err,meta,gratings = [1105,'galex'])
# plot_sed_ms_wd(teff1,m1,m1_err,teff2,m2,m2_err,parallax,parallax_err,meta)
# plot_binary_ms(8000,m1,m1_err,6000,m2,m2_err,parallax,parallax_err)
# plot_triple_ms(9000,m1,m1_err,3500,m2,m2_err,parallax,parallax_err)
plot_kurucz_vs_obs(sources,idx)


# tbl = Table({'wavelength':wl,'flux':fl})
# cond = (tbl['wavelength'] > 925) & (tbl['wavelength'] < 80000)
# tbl = tbl[cond]
# tbl.write('../SBA28.dat',format='ascii',overwrite=True)