In [2]:
import numpy as np
from astropy.table import Table
from astropy import table
from astroquery.vizier import Vizier
from astroquery.gaia import Gaia
from astroquery.mast import Catalogs
from astropy.coordinates import SkyCoord
from astropy.io.votable import parse
from astropy.constants import G, M_sun, pc, R_sun, h, c, k_B
# from dust_extinction.parameter_averages import G23
import extinction
import astropy.units as u
from scipy.optimize import curve_fit, bisect
from scipy.interpolate import LinearNDInterpolator, interp1d
import WD_models
import os.path
from os import listdir
import matplotlib.pyplot as plt
from csv import writer
from tqdm import tqdm
import warnings

# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings from numpy/astropy.
warnings.filterwarnings('ignore')

# Figure widths: 89 mm and 183 mm converted to inches (25.4 mm per inch).
# Presumably single- and double-column journal widths - confirm with the target journal.
figwidth_single = 89/25.4  # [inch]
figwidth_double = 183/25.4  # [inch]
# IPython magic selecting matplotlib's Tk backend.
%matplotlib tk

# Fitting SEDs

## Functions

In [3]:
## filter zero points, effective wavelengths, column names
##
## One astropy Table per photometric system. Columns:
##   band          - band identifier used as the flux column name throughout the notebook
##   col           - name of the catalogue column holding the measurement for this band
##   err_col       - name of the catalogue column holding its uncertainty (not present for the synthetic table)
##   lambda_eff    - effective wavelength; treated as Angstrom where it is used (multiplied by u.AA)
##   f0, m0        - zero-point flux and zero-point magnitude of the conversion below
##   wd_f0, wd_lambda_eff - alternative zero points / wavelengths; per the inline notes these are the
##                   Bedard values where available (presumably for the white-dwarf model grids - confirm)

## f = f0 * 10^(-0.4*(m-m0)) (erg/s/cm^2/AA) 

## (kept for reference: earlier unit-conversion derivations of the zero points)
## wise_f0 = (np.array([309.540,171.787,31.674]) * u.Jy.to(u.erg/u.s/u.cm**2/u.Hz) * u.erg/u.s/u.cm**2/u.Hz * c.cgs / (np.array([33526,46030,115608]) * u.AA)**2).to(u.erg/u.s/u.cm**2/u.AA)

## gaia_f0 = np.array([3.009167e-21,1.346109e-21,1.638483e-21]) * (u.W/u.m**2/u.nm).to(u.erg/u.s/u.cm**2/u.AA)

# two_mass_f0 = np.array([3.129e-13,1.133e-13,4.283e-14]) * (u.W / u.cm**2 / u.micron).to(u.erg/u.s/u.cm**2/u.AA)


## 2MASS J, H, Ks
twomass_zp_table = Table({'band': ['2MASS.J','2MASS.H', '2MASS.Ks'],
                        'col':['Jmag','Hmag','Kmag'],
                        'err_col':['e_Jmag','e_Hmag','e_Kmag'],
                        'lambda_eff':[12350,16620,21590],
                        'f0':[3.129e-10,1.133e-10,4.283e-11],
                        'm0':[0,0,0],
                        'wd_f0':[3.106e-10,1.143e-10,4.206e-11], ## Bedard f0 differs from the VOSA value
                        'wd_lambda_eff':[12350,16620,21590]}) 

## WISE W1, W2, W3
wise_zp_table = Table({'band': ['WISE.W1','WISE.W2','WISE.W3'],
                       'col':['W1mag','W2mag','W3mag'],
                          'err_col':['e_W1mag','e_W2mag','e_W3mag'],
                          'lambda_eff':[33526,46028,115608],
                                'f0':[8.18e-12,2.42e-12,6.52e-14],
                                'm0':[0,0,0],
                                'wd_f0':[8.18e-12,2.42e-12,6.52e-14],
                                'wd_lambda_eff':[33526,46028,115608]}) ## Bedard doesn't give WISE f0, so wd_* repeats f0 / lambda_eff

## Gaia DR3 BP, G, RP. Note that err_col names the *flux* error columns, not magnitude errors.
gaia_zp_table = Table({'band': ['GAIA3.Gbp','GAIA3.G','GAIA3.Grp'],
                       'col':['phot_bp_mean_mag','phot_g_mean_mag','phot_rp_mean_mag'],
                       'err_col':['phot_bp_mean_flux_error','phot_g_mean_flux_error','phot_rp_mean_flux_error'],
                       'lambda_eff':[5035,5822,7619],
                       'f0':[4.08e-9,2.5e-9,1.27e-9],
                       'm0':[0,0,0],
                       'wd_f0':[4.08e-9,2.5e-9,1.27e-9], ## Bedard doesn't give GAIA f0, so wd_f0 repeats f0
                       'wd_lambda_eff':[5035,5822,7619]})

## GALEX FUV, NUV
galex_zp_table = Table({'band': ['GALEX.FUV', 'GALEX.NUV'],
                       'col':['fuv_mag','nuv_mag'],
                       'err_col':['fuv_magerr','nuv_magerr'],
                       'lambda_eff':[1548,2303],
                       'f0':[4.6e-8,2.05e-8],
                       'm0':[0,0],
                       'wd_f0':[4.6e-8,2.05e-8], ## Bedard doesn't give GALEX f0 - it's AB anyway
                       'wd_lambda_eff':[1548,2303]})

## Gaia synthetic photometry (Johnson UBVRI and SDSS ugriz); 'col' holds flux columns, and the
## matching error column is derived elsewhere by appending '_error' (no 'err_col' here).
synt_zp_table = Table({'band': ['Johnson.U', 'Johnson.B','Johnson.V','Johnson.R','Johnson.I',
                             'SDSS.u','SDSS.g','SDSS.r','SDSS.i','SDSS.z'],
                     'lambda_eff':[3551.05,4369.53,5467.57,6695.83,8568.89,3608.04,4671.78,6141.12,7457.89,8922.78],
                    'col':['u_jkc_flux','b_jkc_flux','v_jkc_flux','r_jkc_flux','i_jkc_flux',
                           'u_sdss_flux','g_sdss_flux','r_sdss_flux','i_sdss_flux','z_sdss_flux'],
                  'f0':[3.49719e-9,6.72553e-9,3.5833e-9,1.87529e-9,9.23651e-10,
                        3.75079e-9,5.45476e-9,2.49767e-9,1.38589e-9,8.38585e-10],
                        'm0':[0,0,0,0,0,0,0,0,0,0],
                        'wd_f0':[3.684e-9,6.548e-9,3.804e-9,2.274e-9,1.119e-9,
                                 1.1436e-8,4.9894e-9,2.8638e-9,1.9216e-9,1.3343e-9],
                        'wd_lambda_eff':[3971,4491,5423,6441,8071,
                                         3146,4670,6156,7471,8918]}) ## Bedard f0 differs from the VOSA value

## Combined band table: one row per band of every photometric system, with the columns that
## all per-system tables share. Every column is concatenated in the same system order
## (Gaia, WISE, 2MASS, synthetic Johnson/SDSS, GALEX) so that rows stay aligned.
## Building all columns from one loop fixes a copy-paste error: the 'f0' column used to take
## the synthetic-photometry *lambda_eff* values instead of their zero-point fluxes.
_zp_tables = [gaia_zp_table, wise_zp_table, twomass_zp_table, synt_zp_table, galex_zp_table]
bands_table = Table({key: [value for zp_table in _zp_tables for value in zp_table[key]]
                     for key in ['band', 'lambda_eff', 'f0', 'm0', 'wd_f0', 'wd_lambda_eff']})

## order the bands from blue to red
bands_table.sort('lambda_eff')

## column names used by the model grids, listed in the wavelength-sorted band order produced above
## (GALEX.FUV, GALEX.NUV, Johnson.U, SDSS.u, Johnson.B, SDSS.g, GAIA3.Gbp, Johnson.V, GAIA3.G, SDSS.r,
##  Johnson.R, SDSS.i, GAIA3.Grp, Johnson.I, SDSS.z, 2MASS.J, 2MASS.H, 2MASS.Ks, WISE.W1, WISE.W2, WISE.W3)
bands_table.add_column(['FUV','NUV','U','u','B','g','G3_BP','V','G3','r','R','i','G3_RP','I','z','J','H','Ks','W1','W2','W3'],name='wd_band')

In [4]:
## get photometry

def get_synthetic_photometry(source_table):
    """Query Gaia DR3 synthetic photometry (Johnson UBVRI, SDSS ugriz) for the given sources.

    Parameters
    ----------
    source_table : table with a 'source_id' column of Gaia DR3 source ids.

    Returns
    -------
    The query result table with one flux column per band (named as in synt_zp_table['band'])
    and one '<band>_err' column, converted to erg/s/cm^2/AA, plus 'source_id'.
    """
    ## Format the id list explicitly. Interpolating a tuple of numpy integers into the query
    ## uses their repr, which is 'np.int64(...)' on NumPy >= 2 and leaves a trailing comma
    ## for a single source - neither is valid ADQL.
    id_lst = ','.join(str(int(source_id)) for source_id in source_table['source_id'])
    cols_to_query = ','.join(['source_id'] + [col + ',' + col + '_error' for col in synt_zp_table['col']])
    query = 'SELECT ' + cols_to_query + f' FROM gaiadr3.synthetic_photometry_gspc WHERE source_id IN ({id_lst})'
    job = Gaia.launch_job(query)
    result = job.get_results()
    for col, band, lambda_eff in zip(synt_zp_table['col'], synt_zp_table['band'], synt_zp_table['lambda_eff']):
        err_col = col + '_error'
        if band.startswith('Johnson'): ## Johnson fluxes are per unit wavelength (W/m^2/nm): convert using the column's own unit
            result[col] = result[col].to(u.erg/u.s/u.cm**2/u.AA)
            result[err_col] = result[err_col].to(u.erg/u.s/u.cm**2/u.AA)
        elif band.startswith('SDSS'): ## SDSS fluxes are per unit frequency (W/m^2/Hz): f_lambda = f_nu * c / lambda_eff^2
            wl = lambda_eff * u.AA
            result[col] = (result[col].data * u.W / u.m**2 / u.Hz * c / wl.si**2).to(u.erg/u.s/u.cm**2/u.AA)
            result[err_col] = (result[err_col].data * u.W / u.m**2 / u.Hz * c / wl.si**2).to(u.erg/u.s/u.cm**2/u.AA)
        ## rename the catalogue columns to the notebook's band names
        result[col].name = band
        result[err_col].name = band +'_err'
    ## the id column may come back upper-cased; normalise it so joins on 'source_id' work
    if 'SOURCE_ID' in result.colnames:
        result.rename_column('SOURCE_ID','source_id')
    return result

def get_gaia_photometry(source_table):
    """Query Gaia DR3 G/BP/RP photometry and convert it to physical fluxes.

    Parameters
    ----------
    source_table : table with a 'source_id' column of Gaia DR3 source ids.

    Returns
    -------
    Table with 'source_id', one flux column per band in gaia_zp_table['band'] and one
    '<band>_err' column, all in erg/s/cm^2/AA.
    """
    ## Format the id list explicitly. Interpolating a tuple of numpy integers into the query
    ## uses their repr, which is 'np.int64(...)' on NumPy >= 2 and leaves a trailing comma
    ## for a single source - neither is valid ADQL.
    id_lst = ','.join(str(int(source_id)) for source_id in source_table['source_id'])
    mag_cols = list(gaia_zp_table['col'])
    flux_cols = [col.replace('mag','flux') for col in mag_cols]
    cols_to_query = ','.join(['source_id'] + mag_cols + [fc + ',' + fc + '_error' for fc in flux_cols])
    query = 'SELECT ' + cols_to_query + f' FROM gaiadr3.gaia_source WHERE source_id IN ({id_lst})'
    job = Gaia.launch_job(query)
    result = job.get_results()
    for col, band, err_col, f0, m0 in zip(gaia_zp_table['col'], gaia_zp_table['band'], gaia_zp_table['err_col'],
                                          gaia_zp_table['f0'], gaia_zp_table['m0']):
        flux_col = col.replace('mag','flux')
        ## fractional error - the catalogue fluxes are in arbitrary units, but the fractional error is the same as in erg/s/cm^2/AA
        err_over_f = result[flux_col + '_error'] / result[flux_col]
        mag = result[col]
        result[col] = f0 * 10**(-0.4*(mag-m0))
        result[col].unit = u.erg / u.s / u.cm**2 / u.AA
        result[col].name = band
        ## convert the fractional error to an absolute error in physical units
        ## (err_col is the catalogue flux-error column, which is overwritten here)
        result[err_col] = result[band] * err_over_f
        result[err_col].unit = u.erg / u.s / u.cm**2 / u.AA
        result[err_col].name = band + '_err'
    ## the id column may come back upper-cased; normalise it so joins on 'source_id' work
    if 'SOURCE_ID' in result.colnames:
        result.rename_column('SOURCE_ID','source_id')
    result.keep_columns(['source_id'] + list(gaia_zp_table['band']) + [band + '_err' for band in gaia_zp_table['band']])
    return result

def get_wise_photometry(source_table):
    """Cross-match the sources with the VizieR catalogue 'II/311' and return W1-W3 fluxes.

    source_table must provide 'ra', 'dec' (degrees, ICRS) and 'idx' columns.
    Returns a table sorted by 'idx' with one flux column per band in wise_zp_table['band']
    and one '<band>_err' column (erg/s/cm^2/AA). Bands with flux/error < 3 or a masked error
    are set to NaN, and sources without any match get an all-NaN row.
    NOTE(review): several matches within 2 arcsec yield several rows with the same 'idx' - confirm this is intended.
    """
    coords = SkyCoord(source_table['ra'], source_table['dec'], unit=(u.deg, u.deg), frame='icrs')
    wise = Vizier.query_region(coords, radius=2*u.arcsec,catalog=['II/311'])
    if len(wise) == 0:
        ## nothing returned at all: build an all-NaN placeholder with the expected columns
        wise = Table({'idx':source_table['idx'],'W1mag':np.full_like(source_table['idx'], np.nan, dtype=float),
                      'W2mag':np.full_like(source_table['idx'], np.nan, dtype=float),
                      'W3mag':np.full_like(source_table['idx'], np.nan, dtype=float),
                      'e_W1mag':np.full_like(source_table['idx'], np.nan, dtype=float),
                      'e_W2mag':np.full_like(source_table['idx'], np.nan, dtype=float),
                      'e_W3mag':np.full_like(source_table['idx'], np.nan, dtype=float)})
    else:
        wise = wise[0]
        ## '_q' is treated as the 1-based position of the matched query coordinate:
        ## shift it to 0-based, then translate it into the caller's 'idx' value
        wise['_q'] = wise['_q'] - 1
        wise['_q'] = source_table['idx'][wise['_q']] 
        wise['_q'].name = 'idx'
    wise.keep_columns(['idx'] + list(wise_zp_table['col']) + list(wise_zp_table['err_col']))
    for col in wise_zp_table['col']:
        ## look up this band's metadata, then convert the magnitude column in place and rename it
        band = wise_zp_table[wise_zp_table['col']==col]['band'][0]
        err_col = wise_zp_table[wise_zp_table['col']==col]['err_col'][0]
        f0 = wise_zp_table[wise_zp_table['col']==col]['f0'][0]
        m0 = wise_zp_table[wise_zp_table['col']==col]['m0'][0]
        wise[col] = f0 * 10**(-0.4*(wise[col]-m0))
        wise[col].unit = u.erg / u.s / u.cm**2 / u.AA
        wise[col].name = band
        wise[err_col] = wise[band] * 0.4 * np.log(10) * wise[err_col] ## convert mag error to flux error
        wise[err_col].unit = u.erg / u.s / u.cm**2 / u.AA
        wise[err_col].name = band + '_err'
    for band in wise_zp_table['band']: ## mask out sources with low SNR (only keep sources WISE would flag A or B)
        for i in range(len(wise)):
            if (wise[i][band] / wise[i][band + '_err'] < 3) or np.ma.is_masked(wise[i][band+'_err']):
                wise[i][band] = np.nan
                wise[i][band + '_err'] = np.nan
    ## add an all-NaN row for every source without a match; the positional row assumes
    ## 'idx' is the first column of the table
    for idx in source_table['idx']:
        if idx not in wise['idx']:
            wise.add_row([idx] + [np.nan for col in wise_zp_table['col']] + [np.nan for err_col in wise_zp_table['err_col']])
    wise.sort('idx')
    return wise

def get_twomass_photometry(source_table):
    """Cross-match the sources with the VizieR catalogue 'II/246' and return J/H/Ks fluxes.

    Parameters
    ----------
    source_table : table with 'ra', 'dec' (degrees, ICRS) and 'idx' columns.

    Returns
    -------
    Table sorted by 'idx' with one flux column per band in twomass_zp_table['band'] and one
    '<band>_err' column (erg/s/cm^2/AA). A band is set to NaN unless its character in the
    'Qflg' quality flag is 'A'; sources without any match get an all-NaN row.
    """
    coords = SkyCoord(source_table['ra'], source_table['dec'], unit=(u.deg, u.deg), frame='icrs')
    twomass = Vizier.query_region(coords, radius=2*u.arcsec,catalog=['II/246'])
    if len(twomass) == 0:
        ## nothing returned at all: build an all-NaN placeholder with the expected columns
        n_src = len(source_table)
        placeholder = {'idx': source_table['idx']}
        for name in list(twomass_zp_table['col']) + list(twomass_zp_table['err_col']):
            placeholder[str(name)] = np.full(n_src, np.nan)
        twomass = Table(placeholder)
        ## The quality flag must exist on this path too (it used to be undefined here, so the
        ## masking loop below failed), and it needs one character per band.
        quality_flag = np.full(n_src, 'ZZZ')
    else:
        twomass = twomass[0]
        ## '_q' is treated as the 1-based position of the matched query coordinate:
        ## shift it to 0-based, then translate it into the caller's 'idx' value
        twomass['_q'] = twomass['_q'] - 1
        twomass['_q'] = source_table['idx'][twomass['_q']]
        twomass['_q'].name = 'idx'
        ## keep a handle on the flag column before keep_columns drops it from the table
        quality_flag = twomass['Qflg']
    twomass.keep_columns(['idx'] + list(twomass_zp_table['col']) + list(twomass_zp_table['err_col']))
    for col, band, err_col, f0, m0 in zip(twomass_zp_table['col'], twomass_zp_table['band'], twomass_zp_table['err_col'],
                                          twomass_zp_table['f0'], twomass_zp_table['m0']):
        twomass[col] = f0 * 10**(-0.4*(twomass[col]-m0))
        twomass[col].unit = u.erg / u.s / u.cm**2 / u.AA
        twomass[col].name = band
        twomass[err_col] = twomass[band] * 0.4 * np.log(10) * twomass[err_col] ## convert mag error to flux error
        twomass[err_col].unit = u.erg / u.s / u.cm**2 / u.AA
        twomass[err_col].name = band + '_err'
    for i in range(len(twomass)):
        for j, band in enumerate(twomass_zp_table['band']): ## J,H,Ks bands respectively
            if quality_flag[i][j] != 'A': ## mask out sources according to 2MASS quality flag
                twomass[i][band] = np.nan
                twomass[i][band + '_err'] = np.nan
    ## add an all-NaN row for every source without a match; the positional row assumes
    ## 'idx' is the first column of the table
    for idx in source_table['idx']:
        if idx not in twomass['idx']:
            twomass.add_row([idx] + [np.nan for col in twomass_zp_table['col']] + [np.nan for err_col in twomass_zp_table['err_col']])
    twomass.sort('idx')
    return twomass

def get_galex_photometry(source_table):
    """Collect GALEX FUV/NUV fluxes from MAST ('Galex') and VizieR ('II/335/galex_ais').

    source_table must provide 'ra', 'dec' (degrees, ICRS) and 'idx' columns.
    Returns a table with one row per source, in source_table order, holding 'idx', one flux
    column per band in galex_zp_table['band'] and one '<band>_err' column (erg/s/cm^2/AA);
    bands without data stay NaN.
    NOTE(review): the VizieR value only replaces the MAST value when MAST returned a
    measurement for that source and band; a VizieR-only detection is ignored - confirm this is intended.
    """
    coords = SkyCoord(source_table['ra'], source_table['dec'], unit=(u.deg, u.deg), frame='icrs')
    ## output table: all-NaN flux/error columns, one row per source, aligned with coords
    galex = Table({key:np.full_like(source_table['idx'],np.nan,float) for key in list(galex_zp_table['col'])+list(galex_zp_table['err_col'])})
    galex['idx'] = source_table['idx']
    galex_dr7 = Vizier.query_region(coords, radius=2*u.arcsec,catalog=['II/335/galex_ais'])
    if len(galex_dr7) == 0:
        ## nothing returned by VizieR: all-NaN placeholder with the expected columns
        galex_dr7 = Table({'idx':source_table['idx'],'nuv_mag':np.full_like(source_table['idx'], np.nan, dtype=float),
                           'nuv_magerr':np.full_like(source_table['idx'], np.nan, dtype=float),'fuv_mag':np.full_like(source_table['idx'], np.nan, dtype=float),
                           'fuv_magerr':np.full_like(source_table['idx'], np.nan, dtype=float)})
    else:
        galex_dr7 = galex_dr7[0]    
        ## rename the VizieR columns to the names used in galex_zp_table
        galex_dr7['NUVmag'].name = 'nuv_mag'
        galex_dr7['e_NUVmag'].name = 'nuv_magerr'
        galex_dr7['FUVmag'].name = 'fuv_mag'
        galex_dr7['e_FUVmag'].name = 'fuv_magerr'
        ## '_q' is treated as the 1-based position of the matched query coordinate:
        ## shift it to 0-based, then translate it into the caller's 'idx' value
        galex_dr7['_q'] = galex_dr7['_q'] - 1
        galex_dr7['_q'] = source_table['idx'][galex_dr7['_q']]
        galex_dr7['_q'].name = 'idx'

    ## one MAST query per source; only the first returned row (data[...][0]) is used
    for i in range(len(coords)):
        data = Catalogs.query_region(coords[i], radius=2*u.arcsec,catalog='Galex')
        if len(data) == 0: ## no GALEX data found for this source
            continue 
        for col in galex_zp_table['col']:
            if np.ma.is_masked(data[col][0]): ## GALEX data is missing in this band 
                continue
            err_col = galex_zp_table[galex_zp_table['col']==col]['err_col'][0]
            f0 = galex_zp_table[galex_zp_table['col']==col]['f0'][0]
            m0 = galex_zp_table[galex_zp_table['col']==col]['m0'][0]
            ## magnitude -> flux, and magnitude error -> flux error
            galex[i][col] = f0 * 10**(-0.4*(data[col][0]-m0))
            galex[i][err_col] = galex[i][col] * 0.4 * np.log(10) * data[err_col][0]
            if galex[i]['idx'] in galex_dr7['idx']: ## if there is a match in the GALEX DR7 catalog, use that data instead
                j = np.where(galex_dr7['idx']==galex[i]['idx'])[0][0]
                if not np.ma.is_masked(galex_dr7[j][col]) and not np.isnan(galex_dr7[j][col]):
                    galex[i][col] = f0 * 10**(-0.4*(galex_dr7[j][col]-m0))
                    galex[i][err_col] = galex[i][col] * 0.4 * np.log(10) * galex_dr7[j][err_col]
            
    ## attach units and rename the columns to the notebook's band names
    for col in galex_zp_table['col']:
        band = galex_zp_table[galex_zp_table['col']==col]['band'][0]
        err_col = galex_zp_table[galex_zp_table['col']==col]['err_col'][0]
        galex[col].unit = u.erg / u.s / u.cm**2 / u.AA
        galex[col].name = band
        galex[err_col].unit = u.erg / u.s / u.cm**2 / u.AA
        galex[err_col].name = band + '_err'
    return galex
        
def get_photometry(source_table):
    """Assemble the observed photometry of all sources into one table.

    Adds an 'idx' column (0..N-1) to source_table when it has none, then left-joins the
    Gaia, synthetic, GALEX, 2MASS and WISE fluxes onto the basic source columns.
    Returns the joined table sorted by 'idx'.
    """
    if 'idx' not in source_table.colnames:
        source_table['idx'] = np.arange(len(source_table))
    tbl = source_table['idx','source_id','ra','dec','parallax','parallax_error','[Fe/H]','Av']

    join_opts = dict(join_type='left', metadata_conflicts='silent')
    ## Gaia-based catalogues are matched on the source id ...
    tbl = table.join(tbl, get_gaia_photometry(source_table), keys='source_id', **join_opts)
    tbl = table.join(tbl, get_synthetic_photometry(tbl), keys='source_id', **join_opts)
    ## ... positional cross-matches on the running index
    for fetch in (get_galex_photometry, get_twomass_photometry, get_wise_photometry):
        tbl = table.join(tbl, fetch(tbl), keys='idx', **join_opts)

    flux_cols = [band
                 for zp_table in (gaia_zp_table, wise_zp_table, twomass_zp_table, synt_zp_table, galex_zp_table)
                 for band in zp_table['band']]
    for col in flux_cols:
        err_name = col + '_err'
        snr = tbl[col] / tbl[err_name]
        for row_i in range(len(tbl)):
            ## if SNR > 10, set error to 10% of flux: minimal error to account for model uncertainties
            if snr[row_i] > 10:
                tbl[row_i][err_name] = 0.1 * tbl[row_i][col]

    tbl.sort('idx')
    return tbl

def get_photometry_single_source(source):
    """Collect the photometry of a single source given as a one-row table.

    Columns missing from `source` ('ra', 'dec', 'parallax', 'parallax_error', 'Av') are
    filled from gaiadr3.gaia_source ('Av' from ag_gspphot); 'idx' defaults to 0.
    Returns a one-row table with the source columns and all joined photometry.
    """
    query = f'''SELECT source_id, ra, dec, parallax, parallax_error, ag_gspphot FROM gaiadr3.gaia_source WHERE source_id = {source['source_id'][0]}'''
    gaia_row = Gaia.launch_job(query).get_results()

    tbl = Table(source)

    if 'idx' not in source.colnames:
        tbl['idx'] = 0
    ## notebook column -> gaia_source column used as fallback
    fallback_cols = {'ra': 'ra', 'dec': 'dec', 'parallax': 'parallax',
                     'parallax_error': 'parallax_error', 'Av': 'ag_gspphot'}
    for name, gaia_name in fallback_cols.items():
        if name not in source.colnames:
            tbl[name] = gaia_row[gaia_name][0]

    tbl = tbl['idx','source_id','ra','dec','parallax','parallax_error','Av']
    ## duplicate the row so the per-catalogue helpers receive a two-row table
    ## (presumably because they were written for multi-source input - confirm)
    tbl.add_row(tbl[0])

    fetchers = (('source_id', get_gaia_photometry),
                ('source_id', get_synthetic_photometry),
                ('idx', get_galex_photometry),
                ('idx', get_twomass_photometry),
                ('idx', get_wise_photometry))
    for key, fetch in fetchers:
        tbl = table.join(tbl, fetch(tbl), keys=key, join_type='left', metadata_conflicts='silent')
    ## only the first of the duplicated rows is returned
    return Table(tbl[0])

In [5]:
## model and fitting functions

## observed-flux column names of the form 'Obs.Flux_<survey>.<band>', grouped by survey
filter_list = ['Obs.Flux_' + survey + '.' + band
               for survey, survey_bands in (
                   ('2MASS', ('H', 'J', 'Ks')),
                   ('ACS_WFC', ('F606W', 'F814W')),
                   ('APASS', ('B', 'V')),
                   ('DECam', ('Y', 'g', 'i', 'r', 'z')),
                   ('DENIS', ('I', 'J', 'Ks')),
                   ('GAIA3', ('G', 'Gbp', 'Grp', 'Grvs')),
                   ('GALEX', ('FUV', 'NUV')),
                   ('Johnson', ('B', 'I', 'R', 'U', 'V')),
                   ('PS1', ('g', 'i', 'r', 'y', 'z')),
                   ('SDSS', ('g', 'i', 'r', 'u', 'z')),
                   ('UKIDSS', ('K',)),
                   ('VISTA', ('H', 'J', 'Ks', 'Y')),
                   ('WISE', ('W1', 'W2', 'W3')),
               )
               for band in survey_bands]

def get_distance(parallax,parallax_err):
    """Return (distance, distance_error) in meters for a parallax and its error given in mas."""
    parsec_m = pc.value  ## one parsec in meters
    distance = (1000/parallax) * parsec_m
    ## first-order error propagation of d = 1000/parallax
    distance_err = (1000/parallax**2)*parallax_err * parsec_m
    return distance, distance_err

def blackbody_flux(wavelength, temperature):
    """Blackbody flux 2*pi*h*c^2/lambda^5 / (exp(h*c/(lambda*k_B*T)) - 1) in erg/s/cm^2/AA.

    wavelength and temperature must have astropy units.
    """
    lam = wavelength.cgs
    prefactor = 2 * np.pi * h.cgs * c.cgs**2 / lam**5
    exponent = h.cgs * c.cgs / (lam * k_B.cgs * temperature)
    f = prefactor / (np.exp(exponent) - 1)
    return f.to(u.erg / u.s / u.cm**2 / u.AA)

def blackbody_mod_table(teff,radius,parallax,parallax_err,Av):
    """Reddened blackbody model fluxes in every band of bands_table.

    teff is in K, radius in solar radii, parallax and parallax_err in mas, Av is the
    extinction passed on to the `extinction` helpers. Returns a one-row table with one
    column per band holding the transmission-weighted blackbody flux, scaled by
    (radius/distance)^2 and attenuated by the band extinction.
    """
    bands = list(bands_table['band'])
    mod_tbl = Table({b: [1.0] for b in bands})
    d, d_err = get_distance(parallax,parallax_err)
    radius = radius * R_sun.value  ## solar radii -> meters, same length unit as d

    ## per-band extinctions
    AW1,AW2,AW3,AW4 = extinction.get_WISE_extinction(Av,0,teff)
    AJ,AH,AKs = extinction.get_2MASS_extinction(Av,0,teff)
    AU,AB,AV,AR,AI = extinction.get_Johnson_extinction(Av,0,teff)
    Au,Ag,Ar,Ai,Az = extinction.get_SDSS_extinction(Av,0,teff)
    AG,AGbp,AGrp = extinction.get_Gaia_extinction(Av,0,teff)
    AFUV,ANUV = extinction.get_Galex_extinction(Av,0,teff)

    ext_dict = {
        'GALEX.FUV':AFUV,'GALEX.NUV':ANUV,
        'GAIA3.G':AG,'GAIA3.Gbp':AGbp,'GAIA3.Grp':AGrp,
        'Johnson.U':AU,'Johnson.B':AB,'Johnson.V':AV,'Johnson.R':AR,'Johnson.I':AI,
        'SDSS.u':Au,'SDSS.g':Ag,'SDSS.r':Ar,'SDSS.i':Ai,'SDSS.z':Az,
        '2MASS.J':AJ,'2MASS.H':AH,'2MASS.Ks':AKs,
        'WISE.W1':AW1,'WISE.W2':AW2,'WISE.W3':AW3,'WISE.W4':AW4,
    }

    dilution = (radius/d)**2
    for b in bands:
        ## filter transmission curve: wavelength (treated as Angstrom) and transmission
        filter_tbl = Table.read(os.path.join('..','data','VOSA','filters',b+'.dat'),
                                format='ascii', names=['wavelength','transmission'])
        transmission = filter_tbl['transmission'].data
        bb = blackbody_flux(filter_tbl['wavelength'].data * u.AA, teff * u.K)
        ## transmission-weighted mean of the blackbody over the filter, diluted by distance
        band_flux = np.dot(bb, transmission) / np.sum(transmission) * dilution
        mod_tbl[0][b] = band_flux.value
        mod_tbl[b].unit = band_flux.unit 
        mod_tbl[b] = mod_tbl[b] * 10**(-0.4 * ext_dict[b])

    return mod_tbl
        
def get_chi2(obs_tbl_dered, mod_tbl, no_uv = False):
    """Chi-square between observed and model fluxes, divided by the number of bands used.

    Parameters
    ----------
    obs_tbl_dered : table-like with flux columns and matching '<column>_err' columns.
    mod_tbl : table-like with model flux columns; only columns present in both are compared.
    no_uv : if True, columns whose name starts with 'GALEX' are skipped.

    Returns
    -------
    Sum of squared, error-normalised residuals divided by the number of compared bands
    (the count is not reduced by the number of fitted parameters). Returns NaN when no
    band can be compared, instead of raising ZeroDivisionError.
    """
    chi2 = 0
    ndof = 0
    for col in mod_tbl.colnames:
        if no_uv and col.startswith('GALEX'):
            continue
        if col not in obs_tbl_dered.colnames:
            continue
        err_col = col + '_err'
        ## skip bands whose flux or error is masked
        if np.ma.is_masked(obs_tbl_dered[col]) or np.ma.is_masked(obs_tbl_dered[err_col]):
            continue
        chi2 += np.sum((obs_tbl_dered[col] - mod_tbl[col])**2 /(obs_tbl_dered[err_col])**2)
        ndof += 1
    if ndof == 0:
        ## nothing to compare: the statistic is undefined
        return np.nan
    return chi2/ndof

def plot_obs_sed(sources,idx,model=None,title='SED',no_uv = True, no_IR = False):
    """Plot the observed SED of one source, optionally with a blackbody model.

    Parameters
    ----------
    sources : source table passed to get_photometry.
    idx : row position of the source in the photometry table.
    model : None, or a sequence (teff, radius, teff_err, radius_err, chi2) describing the
        fitted blackbody; when given, the model SED and a text box with the fit results are drawn.
    title : figure title.
    no_uv, no_IR : mark points below 2500 AA / above 10000 AA as not used in the fit.

    Returns
    -------
    The matplotlib figure.
    """
    warnings.simplefilter('ignore', UserWarning)
    obs_tbl = get_photometry(sources)
    bnds = list(bands_table['band'])
    wl = np.array(list(bands_table['lambda_eff']))
    flux = np.array([obs_tbl[idx][bnd] for bnd in bnds])
    flux_err = np.array([obs_tbl[idx][bnd + '_err'] for bnd in bnds])

    if model is not None:
        ## constructing and plotting model SED
        fig, ax = plt.subplots(figsize=(1.5*figwidth_single, figwidth_single), tight_layout=True,dpi=300)
        teff, radius, teff_err,radius_err, chi2 = model
        mod_tbl = blackbody_mod_table(teff,radius,obs_tbl[idx]['parallax'],obs_tbl[idx]['parallax_error'],obs_tbl[idx]['Av'])
        mod_flux = np.array([mod_tbl[bnd][0] for bnd in bnds])
        mod_flux = wl * mod_flux
        ## raw f-strings: the TeX commands contain backslashes that are not valid Python escapes
        ax.plot(wl, mod_flux,color='Navy',ls='--',lw=.5,label=rf'BB {int(teff)}K/{radius:.2f} $R_\odot$',marker='d',mec='RoyalBlue',mfc='RoyalBlue',markersize=2, zorder = 0)

        ## plotting fit results (line breaks are added as separate, non-raw strings)
        textstr = (rf'$\chi^2/ndof$= {chi2:.2f} ' + '\n'
                   + '$T_{eff,1}=$' + rf'{int(teff)}$\pm${int(teff_err)} K' + '\n'
                   + '$R_1=$' + rf'{radius:.2f}$\pm${radius_err:.2f} $R_\odot$')
        props = dict(boxstyle='round', facecolor='Silver', alpha=0.9)
        ax.text(0.35, 0.5, textstr, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', bbox=props, color='k')
    else:
        fig,ax = plt.subplots(figsize=(figwidth_single, figwidth_single),dpi=300)

    ## plotting observed SED (as lambda * f_lambda) with error bars, mark UV and IR points if not used in fit
    flux = wl * flux
    ax.errorbar(wl,flux,yerr = wl * flux_err,fmt='.',label='observed',ecolor='k',elinewidth=1,mec='Crimson',mfc='Crimson',capsize=2,markersize=2, zorder = 5)

    no_fit_cut = ((wl < 2500) * no_uv) | ((wl>10000) * bool(no_IR))
    if np.count_nonzero(~np.isnan(flux[no_fit_cut])) > 0:
        ax.scatter(wl[no_fit_cut],flux[no_fit_cut],marker='.',c='Gold',label='No fit',s=3,zorder = 10)

    ## aesthetics
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_title(title,fontsize=8)
    ax.set_xlabel(r'$\lambda$ $(\AA)$',fontsize=10)
    ax.set_ylabel(r'$\lambda f_\lambda$ ($erg/s/cm^2/\AA$)',fontsize=10)
    ax.tick_params(axis='both', which='major', labelsize=8)
    ax.grid()
    ax.legend(fontsize=10,loc='lower center')
    fig.show()
    return fig

def fit_table(sources, init = (6000,1) ,teff_bounds = (3000,10000), r_bounds = (0.4,3.0), no_uv = True, no_IR = []):
    """Fit a reddened blackbody (Teff, radius) to the observed SED of every source.

    init, teff_bounds and r_bounds are each either one tuple applied to all sources or a
    list with one tuple per source. GALEX bands are excluded when no_uv is True; 2MASS and
    WISE bands are excluded for sources whose 'idx' is listed in no_IR.
    Returns a table with columns teff1, teff1_err, r1, r1_err, chi2/ndof and no_IR.
    NOTE(review): rows of the photometry table (sorted by 'idx') are paired positionally
    with the rows of `sources` - this presumably assumes `sources` is ordered by 'idx'.
    """
    obs_tbl = get_photometry(sources)
    # obs_tbl_dered = deredden_obs_table(obs_tbl)
    n_src = len(sources)
    results = Table(np.zeros((n_src,6)), names= ['teff1','teff1_err','r1','r1_err','chi2/ndof','no_IR'])

    ## broadcast single tuples to one entry per source, validate list lengths
    if type(init) is tuple:
        init = [init] * n_src
    elif len(init) != n_src:
        raise ValueError('init must be a tuple or a list of tuples of length len(sources)')
    if type(teff_bounds) is tuple:
        teff_bounds = [teff_bounds] * n_src
    elif len(teff_bounds) != n_src:
        raise ValueError('teff_bounds must be a tuple or a list of tuples of length len(sources)')
    if type(r_bounds) is tuple:
        r_bounds = [r_bounds] * n_src
    elif len(r_bounds) != n_src:
        raise ValueError('r_bounds must be a tuple or a list of tuples of length len(sources)')

    lambda_of_band = dict(zip(bands_table['band'], bands_table['lambda_eff']))
    per_source = zip(obs_tbl, init, teff_bounds, r_bounds)
    for i, (obs_row, p0, t_lim, r_lim) in enumerate(tqdm(per_source, total=n_src)):
        idx = sources['idx'][i]
        skip_ir = idx in no_IR

        ## select the bands that enter the fit
        bnds = []
        for b in bands_table['band']:
            if no_uv and b.startswith('GALEX'):
                continue
            if skip_ir and (b.startswith('2MASS') or b.startswith('WISE')):
                continue
            if np.ma.is_masked(obs_row[b]) or np.isnan(obs_row[b]):
                continue
            bnds.append(b)
        wl = np.array([lambda_of_band[bnd] for bnd in bnds])

        def bb_fun(wl,teff,radius):
            ## model fluxes in the selected bands; wl is only required by the curve_fit signature
            mod_tbl = blackbody_mod_table(teff,radius,obs_row['parallax'],obs_row['parallax_error'],obs_row['Av'])
            return np.array([mod_tbl[bnd][0] for bnd in bnds])

        flux = np.array([obs_row[bnd] for bnd in bnds])
        flux_err = np.array([obs_row[bnd + '_err'] for bnd in bnds])
        lower = [t_lim[0], r_lim[0]]
        upper = [t_lim[1], r_lim[1]]
        vals, covm = curve_fit(bb_fun, wl, flux, p0 = list(p0), bounds = (lower, upper), sigma = flux_err, absolute_sigma = True)

        residuals = flux - bb_fun(wl,vals[0],vals[1])
        row = results[i]
        row['teff1'] = vals[0]
        row['teff1_err'] = covm[0,0]**(1/2)
        row['r1'] = vals[1]
        row['r1_err'] = covm[1,1]**(1/2)
        row['chi2/ndof'] = np.sum(residuals**2 / flux_err**2) / (len(flux) - len(vals))
        row['no_IR'] = skip_ir
    return results

In [6]:
## fitting the table

# sources = Table.read('../table_C.fits', format='fits')
# init_lst = [(6000,sources[i]['m1']) for i in range(len(sources))]
# r_bounds_lst = [(0.3*sources[i]['m1'],3*sources[i]['m1']) for i in range(len(sources))]
# teff_bounds_lst = [(3000,20000) for i in range(len(sources))]
# no_IR = []
# results = fit_table(sources, init = init_lst ,teff_bounds = teff_bounds_lst, r_bounds = r_bounds_lst, no_uv = True , no_IR = no_IR)


# # saving the results

# sources['teff1'] = results['teff1']
# sources['teff1_err'] = results['teff1_err']
# sources['r1'] = results['r1']
# sources['r1_err'] = results['r1_err']
# sources['chi2/ndof'] = results['chi2/ndof']
# sources['no_IR'] = results['no_IR']

# fail = []
# # sources.write('../table_C.fits', format='fits', overwrite=True)

# # plotting the results

# for j in j_mv_lbl:
#     fig = plot_obs_sed(sources,j,model=sources[j]['teff1','r1','teff1_err','r1_err','chi2/ndof'],title=f'Candidate {sources[j]["idx"]} SED fitting', no_uv = True, no_IR = sources[j]['no_IR'])
#     fig.savefig(f'../img/bb_fits/bb_{sources[j]["idx"]}.png',bbox_inches='tight')
#     plt.close(fig)

# SED models and UV excess

In [7]:
## SED model functions


## maps '<facility>/<band>' filter identifiers to the band names used in this notebook
filter_band_dict = {facility + '/' + band: band
                    for facility, facility_bands in (
                        ('GALEX', ('GALEX.FUV', 'GALEX.NUV')),
                        ('GAIA', ('GAIA3.G', 'GAIA3.Gbp', 'GAIA3.Grp')),
                        ('Generic', ('Johnson.U', 'Johnson.B', 'Johnson.V', 'Johnson.R', 'Johnson.I')),
                        ('SLOAN', ('SDSS.u', 'SDSS.g', 'SDSS.r', 'SDSS.i', 'SDSS.z')),
                        ('2MASS', ('2MASS.J', '2MASS.H', '2MASS.Ks')),
                        ('WISE', ('WISE.W1', 'WISE.W2', 'WISE.W3')),
                    )
                    for band in facility_bands}


def get_ms_sed(teff,m1,r1,meta,parallax):
    """Model SED of a star interpolated from the grid in '../data/kurucz_sed.dat'.

    teff in K, m1 in solar masses, r1 in solar radii, meta is the grid's metallicity
    coordinate, parallax in mas. Returns a one-row table with one flux column per
    bands_table['wd_band'] entry, scaled by (radius/distance)^2.
    """
    kur = Table(np.genfromtxt('../data/kurucz_sed.dat',names=True,dtype=None))
    wd_bands = bands_table['wd_band']
    model_table = Table(data=[[np.nan] for _ in wd_bands],names=wd_bands)

    ## surface gravity g = G*M/R^2 in cgs
    surface_gravity = G.cgs * m1 * M_sun.cgs / (r1 *R_sun.cgs)**2
    logg = np.log10(surface_gravity.value)

    grid_points = np.array([kur['teff'],kur['logg'],kur['meta']]).T
    ## (R/d)^2 with d = 1000/parallax pc
    dilution = (r1 * R_sun.cgs * parallax / 1000 / pc.cgs)**2
    for band in wd_bands:
        interp = LinearNDInterpolator(grid_points,kur[band])
        model_table[band] = interp([teff,logg,meta])[0] * dilution
        model_table[band].unit = u.erg / u.s / u.cm**2 / u.AA
    return model_table

def logg_from_MR_relation(teff,m,core='CO',atm='H'):
    """Interpolate log g from a white-dwarf model grid in Teff and mass.

    Parameters
    ----------
    teff : effective temperature in K.
    m : mass in solar masses.
    core : 'He', 'CO' or 'ONe'; selects the model file.
    atm : 'H' or 'He'; selects the DA / DB file for CO and ONe cores (ignored for He cores).

    Returns
    -------
    log_g in cm/s^2, or None (after printing a message) when core or atm is not recognised.
    An unknown core used to fail with an unassigned-variable error; it is now reported the
    same way as an unknown atmosphere.
    """
    if core == 'He':
        filename = 'He_wd.dat'
    elif core in ('CO', 'ONe'):
        if atm == 'H':
            filename = core + '_DA.dat'
        elif atm == 'He':
            filename = core + '_DB.dat'
        else:
            print('atm either H or He')
            return None
    else:
        print('core either He, CO or ONe')
        return None
    wd = np.genfromtxt('../data/WD_models/' + filename,names=True,dtype=None)
    ## linear interpolation on the (Teff, Mass) grid of the model file
    interp = LinearNDInterpolator(np.array([wd['Teff'],wd['Mass']]).T,wd['log_g'])
    logg = interp([teff,m])[0]
    return logg

def get_wd_sed(teff,m,parallax,core='CO',atm='H'):
    """White-dwarf model SED from the CO_DA / CO_DB magnitude grids.

    teff in K, m in solar masses, parallax in mas. log g comes from the mass-radius relation
    of the requested core/atmosphere; the magnitudes are always interpolated on the CO grid
    of the requested atmosphere. Returns a one-row table with one flux column per
    bands_table['wd_band'] entry, or None (after printing a message) for an unknown atm.
    NOTE(review): magnitudes are converted with bands_table['f0']; the table also carries
    'wd_f0' zero points - confirm which one is intended here.
    """
    ## use MR relation specific to core and atmosphere type
    logg = logg_from_MR_relation(teff,m,core,atm)

    ## then interpolate Teff, Logg on the CO model to get SED
    if atm not in ('H', 'He'):
        print('atm either H or He')
        return None
    model_path = '../data/WD_models/CO_DA.dat' if atm == 'H' else '../data/WD_models/CO_DB.dat'
    model = np.genfromtxt(model_path,names=True,dtype=None)

    wd_bands = bands_table['wd_band']
    model_table = Table(data=[[np.nan] for _ in wd_bands],names=wd_bands)
    grid_points = np.array([model['Teff'],model['log_g']]).T
    for wd_band,f0,m0 in zip(wd_bands,bands_table['f0'],bands_table['m0']):
        interp = LinearNDInterpolator(grid_points,model[wd_band])
        absmag = interp([teff,logg])[0]
        ## absolute -> apparent magnitude with distance 1000/parallax pc
        apmag = absmag + 5 * np.log10(1000/parallax) - 5
        model_table[wd_band] = f0 * 10**(-0.4*(apmag - m0))
        model_table[wd_band].unit = u.erg / u.s / u.cm**2 / u.AA
    return model_table

def get_wd_koester_sed(teff,m,parallax,core='CO',atm='H'):
    """Build a one-row table of white-dwarf fluxes from the Koester SED grid.

    log g is taken from the mass-radius relation for ``core``/``atm`` and turned
    into a radius via g = G m / r^2. Each band value is interpolated in
    (teff, logg) on ``../data/koester_sed.dat`` and multiplied by the squared
    ratio ``r * R_sun * parallax / 1000 / pc``.

    Returns a Table with one column per ``bands_table['wd_band']`` whose unit is
    set to erg/s/cm^2/AA.
    """
    ## use MR relation specific to core and atmosphere type
    logg = logg_from_MR_relation(teff,m,core,atm)
    gm = G.cgs * m * M_sun.cgs
    radius_sq = gm / (10**logg) / (u.cm/u.s**2)
    r = np.sqrt(radius_sq)/R_sun.cgs ## in R_sun

    ## the dilution factor does not depend on the band
    dilution = (r * R_sun.cgs * parallax / 1000 / pc.cgs)**2

    ## then interpolate Teff, Logg on the koester model to get SED
    model = np.genfromtxt('../data/koester_sed.dat',names=True,dtype=None)
    grid_points = np.array([model['teff'],model['logg']]).T
    model_table = Table(data=[[np.nan] for band in bands_table['wd_band']],names=bands_table['wd_band'])
    for wd_band in bands_table['wd_band']:
        band_interp = LinearNDInterpolator(grid_points,model[wd_band])
        surface_flux = band_interp([teff,logg])[0]
        model_table[wd_band] = surface_flux * dilution
        model_table[wd_band].unit = u.erg / u.s / u.cm**2 / u.AA
    return model_table

def redden_model_table(mod_tbl,teff,av):
    """Apply per-band extinction to the first row of a model flux table, in place.

    Extinctions for WISE, 2MASS, Johnson, SDSS, Gaia and GALEX bands are obtained
    from the ``extinction`` module for the given ``av`` and ``teff``. Every column
    named in ``bands_table['wd_band']`` has its row-0 flux multiplied by
    ``10**(-0.4 * A_band)``. The same (mutated) table is returned.
    """
    a_w1,a_w2,a_w3,a_w4 = extinction.get_WISE_extinction(av,0,teff)
    a_j,a_h,a_ks = extinction.get_2MASS_extinction(av,0,teff)
    a_U,a_B,a_V,a_R,a_I = extinction.get_Johnson_extinction(av,0,teff)
    a_u,a_g,a_r,a_i,a_z = extinction.get_SDSS_extinction(av,0,teff)
    a_G,a_bp,a_rp = extinction.get_Gaia_extinction(av,0,teff)
    a_fuv,a_nuv = extinction.get_Galex_extinction(av,0,teff)

    band_names = ['2MASS.J','2MASS.H','2MASS.Ks','GALEX.FUV','GALEX.NUV','GAIA3.G','GAIA3.Gbp','GAIA3.Grp',
                  'WISE.W1','WISE.W2','WISE.W3','WISE.W4','Johnson.U','Johnson.B','Johnson.V','Johnson.R','Johnson.I',
                  'SDSS.u','SDSS.g','SDSS.r','SDSS.i','SDSS.z']
    band_values = [a_j,a_h,a_ks,a_fuv,a_nuv,a_G,a_bp,a_rp,
                   a_w1,a_w2,a_w3,a_w4,a_U,a_B,a_V,a_R,a_I,
                   a_u,a_g,a_r,a_i,a_z]
    band_extinction = dict(zip(band_names,band_values))

    for band,col in zip(bands_table['band'],bands_table['wd_band']):
        dimming = 10**(-0.4 * band_extinction[band])
        mod_tbl[col][0] = mod_tbl[col][0] * dimming

    return mod_tbl

def get_cooling_temp(age,m,core='CO',atm='H'):
    """Interpolate the effective temperature of a white dwarf of given age and mass.

    The cooling grid is selected from ``core`` and ``atm`` and ``Teff`` is
    linearly interpolated on the grid's (``Age``, ``Mass``) plane.

    Parameters
    ----------
    age : float
        Cooling age, in the units of the grid's ``Age`` column.
    m : float
        Mass in solar masses.
    core : {'He', 'CO', 'ONe'}
        Core composition. 'He' uses a single grid file and ignores ``atm``.
    atm : {'H', 'He'}
        Atmosphere type (only used for 'CO' and 'ONe' cores).

    Returns
    -------
    int, float or None
        Teff truncated to an integer. NaN when (age, m) lies outside the grid
        (previously this raised ValueError from ``int(nan)``). None, after
        printing a message, when ``core`` or ``atm`` is not recognised.
    """
    grid_files = {
        ('CO','H'): '../data/WD_models/CO_DA.dat',
        ('CO','He'): '../data/WD_models/CO_DB.dat',
        ('ONe','H'): '../data/WD_models/ONe_DA.dat',
        ('ONe','He'): '../data/WD_models/ONe_DB.dat',
    }
    if core == 'He':
        grid_path = '../data/WD_models/He_wd.dat'
    elif core in ('CO','ONe'):
        if atm not in ('H','He'):
            print('atm either H or He')
            return None
        grid_path = grid_files[(core,atm)]
    else:
        ## previously an unknown core fell through and crashed with UnboundLocalError
        print('core either He, CO or ONe')
        return None

    model = np.genfromtxt(grid_path,names=True,dtype=None)
    interp = LinearNDInterpolator(np.array([model['Age'],model['Mass']]).T,model['Teff'])
    teff = interp([age,m])[0]
    if np.isnan(teff):
        ## outside the convex hull of the grid: report it instead of crashing
        return np.nan
    return int(teff)

def plot_kurucz_vs_obs(sources,idx,plot=True,save=False):
    """Plot observed photometry of one source against its main-sequence model SED.

    The row of ``sources`` whose 'idx' equals ``idx`` is used. Three model SEDs
    are built with ``get_ms_sed`` and reddened with ``redden_model_table``: the
    nominal one, and a low/high pair obtained by scaling teff1 and r1 by 0.9/1.1
    and shifting m1 and Av by their errors. Everything is drawn as
    lambda * f_lambda versus wavelength on log-log axes.

    Parameters
    ----------
    sources : table
        Must provide the columns read below ('idx', 'teff1', 'r1', 'm1',
        'm1_err', 'parallax', '[Fe/H]', 'Av', 'e_Av', 'no_IR').
    idx : value matched against ``sources['idx']``
        Raises IndexError if no row matches.
    plot : bool
        When False the figure is closed at the end.
    save : bool
        When True the figure is written to ``../img/kurucz/kurucz_{idx}.png``.

    Returns
    -------
    None
    """
    ## NOTE(review): this changes the process-wide warnings filter, not just this call
    warnings.simplefilter('ignore', UserWarning)
    i = np.where(sources['idx'] == idx)[0][0]
    obs_tbl = get_photometry_single_source(Table(sources[i]))

    ## observed fluxes and errors, converted to lambda * f_lambda
    bnds = list(bands_table['band'])
    wl = np.array(list(bands_table['lambda_eff']))
    flux = np.array([obs_tbl[0][bnd] for bnd in bnds])
    flux = flux * wl
    flux_err = np.array([obs_tbl[0][bnd + '_err'] for bnd in bnds])
    flux_err = flux_err * wl

    ## nominal parameters and the +-10% teff / radius envelope
    teff1 = sources[i]['teff1']
    teff_lo = 0.9 * teff1
    teff_hi = 1.1 * teff1
    r1 = sources[i]['r1']
    r_lo = 0.9 * r1
    r_hi = 1.1 * r1
    m = sources[i]['m1']
    m_err = sources[i]['m1_err']
    parallax = sources[i]['parallax']
    meta = sources[i]['[Fe/H]']
    av = sources[i]['Av']
    av_err = sources[i]['e_Av']
    ## the "lo" model pairs low teff / low radius with HIGH mass and HIGH extinction
    ## (and vice versa), presumably to get the widest flux envelope -- TODO confirm
    ms_lo = get_ms_sed(teff_lo, m + m_err, r_lo, meta, parallax)
    ms_hi = get_ms_sed(teff_hi, m - m_err, r_hi, meta, parallax)
    ms_lo = redden_model_table(ms_lo,teff_lo,av+av_err)
    ms_hi = redden_model_table(ms_hi,teff_hi,max([av-av_err,0]))  ## extinction clipped at 0
    ms = get_ms_sed(teff1, m, r1, meta, parallax)
    ms = redden_model_table(ms,teff1,av)

    ## model fluxes as lambda * f_lambda, in bands_table order
    flux_lo = np.array([ms_lo[bnd][0] for bnd in bands_table['wd_band']]) * wl
    flux_hi = np.array([ms_hi[bnd][0] for bnd in bands_table['wd_band']])* wl
    flux_bst = np.array([ms[bnd][0] for bnd in bands_table['wd_band']]) * wl

    fig, ax = plt.subplots(figsize=(figwidth_single, 0.8*figwidth_single), tight_layout=True, dpi=300)
    ax.plot(wl,flux_bst,color='RoyalBlue',ls='--',lw=0.5)
    ax.fill_between(wl,flux_lo,flux_hi,color='RoyalBlue',alpha=0.5,label=f'Best fit Kurucz model + errors')

    ax.set_ylabel(r'$\lambda f_\lambda $($erg/s/cm^2/\AA$)',fontsize=8)
    ax.set_xlabel(r'wavelength $(\AA)$',fontsize=8)
    ax.errorbar(wl, flux,yerr = flux_err,fmt='.',label='observed',ecolor='k',elinewidth=2,mec='Crimson',mfc='Crimson',capsize=2,markersize=1)

    ## gold markers are over-plotted on points beyond 10000 AA (only for rows flagged
    ## 'no_IR') and always on points below 3500 AA -- presumably the bands left out
    ## of the fit; TODO confirm
    if sources[i]['no_IR'] and len(flux[wl>10000]) > 0:
        ax.scatter(wl[wl>10000],flux[wl>10000],marker='.',color='Gold',s=1,zorder = 10)
    ax.scatter(wl[wl<3500],flux[wl<3500],marker='.',color='Gold',s=1,zorder = 10)

    ## NOTE(review): duplicate of the set_ylabel call above
    ax.set_ylabel(r'$\lambda f_\lambda $($erg/s/cm^2/\AA$)',fontsize=8)

    ## NOTE(review): function-local import; the notebook keeps its imports in the first cell
    from matplotlib.ticker import FormatStrFormatter
    ax.set_xticks([2e3,5e3,8e3,1e4,2e4,5e4,1e5])
    ax.get_xaxis().set_major_formatter(FormatStrFormatter('%.0e'))
    ## upper x limit follows the reddest band that has a non-NaN observed flux
    ax.set_xlim(1e3,1.2*wl[~np.isnan(flux)].max())
    ax.set_xscale('log')
    ax.tick_params(axis='x', which='major', labelsize=7)

    ax.set_yscale('log')
    ax.tick_params(axis='y', which='major', labelsize=8)

    ax.legend(fontsize=8,loc='lower right',frameon=False)
    ax.set_title(f'candidate {idx}',fontsize=7)
    ## NOTE(review): the figure is shown even when plot=False; it is only closed afterwards
    fig.show()
    if save:
        fig.savefig(f'../img/kurucz/kurucz_{idx}.png')
    if not plot:
        plt.close(fig)


In [70]:
## UV excess functions

def obs_nuv_flux(mag):
    """Convert a GALEX NUV magnitude to flux with the zero point stored in ``galex_zp_table``."""
    nuv_row = galex_zp_table[galex_zp_table['band'] == 'GALEX.NUV']
    zero_flux = nuv_row['f0'].data[0]
    zero_mag = nuv_row['m0'].data[0]
    return 10**(-0.4 * (mag - zero_mag)) * zero_flux

def obs_fuv_flux(mag):
    """Convert a GALEX FUV magnitude to flux with the zero point stored in ``galex_zp_table``."""
    fuv_row = galex_zp_table[galex_zp_table['band'] == 'GALEX.FUV']
    zero_flux = fuv_row['f0'].data[0]
    zero_mag = fuv_row['m0'].data[0]
    return 10**(-0.4 * (mag - zero_mag)) * zero_flux

def nuv_flux_to_mag(flux):
    """Convert a flux to a GALEX NUV magnitude (inverse of ``obs_nuv_flux``)."""
    nuv_row = galex_zp_table[galex_zp_table['band'] == 'GALEX.NUV']
    zero_flux = nuv_row['f0'].data[0]
    zero_mag = nuv_row['m0'].data[0]
    mag = zero_mag - 2.5 * np.log10(flux / zero_flux)
    return mag

def fuv_flux_to_mag(flux):
    """Convert a flux to a GALEX FUV magnitude (inverse of ``obs_fuv_flux``)."""
    fuv_row = galex_zp_table[galex_zp_table['band'] == 'GALEX.FUV']
    zero_flux = fuv_row['f0'].data[0]
    zero_mag = fuv_row['m0'].data[0]
    mag = zero_mag - 2.5 * np.log10(flux / zero_flux)
    return mag

def get_primary_nuv_lims(t_lo,t,t_hi,m_lo,m,m_hi,r_lo,r,r_hi,parallax,meta,av):
    """Return the reddened NUV model flux of the primary for low, nominal and high parameters.

    Each (teff, mass, radius) triple is passed to ``get_ms_sed`` together with
    ``meta`` and ``parallax``; the 'NUV' value of the first row is then dimmed by
    the GALEX NUV extinction evaluated for ``av`` at the matching teff.

    Returns the tuple ``(nuv_lo, nuv, nuv_hi)``.
    """
    ## get the NUV flux range of the primary star
    sed_hi, sed_mid, sed_lo = [
        get_ms_sed(teff,mass,radius,meta,parallax)
        for teff,mass,radius in ((t_hi,m_hi,r_hi),(t,m,r),(t_lo,m_lo,r_lo))
    ]
    ## second element of the GALEX extinction pair is the NUV extinction
    a_lo, a_mid, a_hi = [extinction.get_Galex_extinction(av,0,teff)[1] for teff in (t_lo,t,t_hi)]

    nuv_hi = sed_hi['NUV'][0] * 10**(-0.4 * a_hi)
    nuv = sed_mid['NUV'][0] * 10**(-0.4 * a_mid)
    nuv_lo = sed_lo['NUV'][0] * 10**(-0.4 * a_lo)
    return nuv_lo, nuv, nuv_hi

def get_primary_fuv_lims(t_lo,t,t_hi,m_lo,m,m_hi,r_lo,r,r_hi,parallax,meta,av):
    """Return the reddened FUV model flux of the primary for low, nominal and high parameters.

    Each (teff, mass, radius) triple is passed to ``get_ms_sed`` together with
    ``meta`` and ``parallax``; the 'FUV' value of the first row is then dimmed by
    the GALEX FUV extinction evaluated for ``av`` at the matching teff.

    Returns the tuple ``(fuv_lo, fuv, fuv_hi)``.
    """
    sed_hi, sed_mid, sed_lo = [
        get_ms_sed(teff,mass,radius,meta,parallax)
        for teff,mass,radius in ((t_hi,m_hi,r_hi),(t,m,r),(t_lo,m_lo,r_lo))
    ]
    raw_hi = sed_hi['FUV'][0]
    raw_mid = sed_mid['FUV'][0]
    raw_lo = sed_lo['FUV'][0]
    ## first element of the GALEX extinction pair is the FUV extinction
    a_lo, a_mid, a_hi = [extinction.get_Galex_extinction(av,0,teff)[0] for teff in (t_lo,t,t_hi)]

    fuv_hi = raw_hi * 10**(-0.4 * a_hi)
    fuv = raw_mid * 10**(-0.4 * a_mid)
    fuv_lo = raw_lo * 10**(-0.4 * a_lo)
    return fuv_lo, fuv, fuv_hi

def get_companion_RT_lims_fuv(obs_mag,fuv_lo,fuv,fuv_hi,m2,parallax,av,core='CO',atm='H'):
    """Estimate the white-dwarf temperature whose reddened FUV flux matches the observed excess.

    The excess is the observed FUV flux minus the nominal primary flux ``fuv``.
    The WD model flux is evaluated on a temperature grid (5000-15000 K for a He
    core, 50000-150000 K for CO/ONe cores), restricted to temperatures whose
    mass-radius log g lies in [7, 9], and the temperature where
    ``model/excess - 1`` crosses zero is found by linear interpolation.

    Parameters
    ----------
    obs_mag : float
        Observed GALEX FUV magnitude.
    fuv_lo, fuv_hi : float
        Accepted for interface compatibility; currently unused.
    fuv : float
        Nominal reddened FUV flux of the primary.
    m2 : float
        Companion mass in solar masses.
    parallax, av : float
        Forwarded to ``get_wd_sed`` and the GALEX extinction routine.
    core : {'He', 'CO', 'ONe'}
    atm : {'H', 'He'}

    Returns
    -------
    float
        The interpolated temperature in K. When no crossing exists inside the
        grid, the sign of (model flux at the coolest grid temperature - excess):
        +1 if not enough observed excess, -1 if too much. NaN when no grid
        temperature is usable.

    Raises
    ------
    ValueError
        For an unrecognised ``core`` or ``atm`` (previously an
        UnboundLocalError / TypeError further down).
    """
    if core == 'He':
        wd_teff_grid = np.linspace(5000,15000,40)
    elif core in ('CO','ONe'):
        # wd_teff_grid = np.linspace(3000,50000,30)
        wd_teff_grid = np.linspace(50000,150000,30)
    else:
        raise ValueError(f"core must be 'He', 'CO' or 'ONe', got {core!r}")
    if atm not in ('H','He'):
        raise ValueError(f"atm must be 'H' or 'He', got {atm!r}")

    def wd_reddened_fuv(t,m2,parallax,av,core,atm): ## auxiliary: for solving Teff2
        wd = get_wd_sed(t,m2,parallax,core,atm)
        fuv = wd['FUV'][0]
        AFUV,_ = extinction.get_Galex_extinction(av,0,t)
        fuv = fuv * 10**(-0.4 * AFUV)
        return fuv

    ## keep only temperatures whose log g is defined and within [7, 9]
    logg_grid = np.array([logg_from_MR_relation(t,m2,core,atm) for t in wd_teff_grid],dtype=float)
    cut = (logg_grid >=7) & (logg_grid <= 9) & ~(np.isnan(logg_grid))
    wd_teff_grid = wd_teff_grid[cut]
    if wd_teff_grid.size == 0:
        ## nothing to interpolate on; the fallback below would crash on .min()
        return np.nan

    excess = obs_fuv_flux(obs_mag) - fuv

    try:
        fvec = np.array([wd_reddened_fuv(t,m2,parallax,av,core,atm)/excess -1 for t in wd_teff_grid])
        finite = ~np.isnan(fvec)
        wd_teff_grid = wd_teff_grid[finite]
        fvec = fvec[finite]
        ## interp1d raises ValueError when 0 is outside the sampled range
        t = float(interp1d(fvec,wd_teff_grid,kind='linear')(0))
    except Exception:
        ## best-effort fallback, but no longer swallowing KeyboardInterrupt/SystemExit
        if wd_teff_grid.size == 0:
            return np.nan
        ## +1 if not enough observed excess, -1 if too much observed excess
        t = float(1 * np.sign((wd_reddened_fuv(wd_teff_grid.min(),m2,parallax,av,core,atm) - excess)))
    return t

def get_companion_RT_lims_nuv(obs_mag,nuv_lo,nuv,nuv_hi,m2,parallax,av,core='CO',atm='H'):
    """Estimate the white-dwarf temperature whose reddened NUV flux matches the observed excess.

    The excess is the observed NUV flux minus the nominal primary flux ``nuv``.
    The WD model flux is evaluated on a temperature grid (5000-15000 K for a He
    core, 50000-150000 K for CO/ONe cores), restricted to temperatures whose
    mass-radius log g lies in [7, 9], and the temperature where
    ``model/excess - 1`` crosses zero is found by linear interpolation.

    Parameters
    ----------
    obs_mag : float
        Observed GALEX NUV magnitude.
    nuv_lo, nuv_hi : float
        Accepted for interface compatibility; currently unused.
    nuv : float
        Nominal reddened NUV flux of the primary.
    m2 : float
        Companion mass in solar masses.
    parallax, av : float
        Forwarded to ``get_wd_sed`` and the GALEX extinction routine.
    core : {'He', 'CO', 'ONe'}
    atm : {'H', 'He'}

    Returns
    -------
    float
        The interpolated temperature in K. When no crossing exists inside the
        grid, the sign of (model flux at the coolest grid temperature - excess):
        +1 if not enough observed excess, -1 if too much. NaN when no grid
        temperature is usable.

    Raises
    ------
    ValueError
        For an unrecognised ``core`` or ``atm`` (previously an
        UnboundLocalError / TypeError further down).
    """
    if core == 'He':
        wd_teff_grid = np.linspace(5000,15000,40)
    elif core in ('CO','ONe'):
        # wd_teff_grid = np.linspace(3000,50000,30)
        wd_teff_grid = np.linspace(50000,150000,30)
    else:
        raise ValueError(f"core must be 'He', 'CO' or 'ONe', got {core!r}")
    if atm not in ('H','He'):
        raise ValueError(f"atm must be 'H' or 'He', got {atm!r}")

    def wd_reddened_nuv(t,m2,parallax,av,core,atm): ## auxiliary: for solving Teff2
        wd = get_wd_sed(t,m2,parallax,core,atm)
        nuv = wd['NUV'][0]
        _,ANUV = extinction.get_Galex_extinction(av,0,t)
        nuv = nuv * 10**(-0.4 * ANUV)
        return nuv

    ## keep only temperatures whose log g is defined and within [7, 9]
    logg_grid = np.array([logg_from_MR_relation(t,m2,core,atm) for t in wd_teff_grid],dtype=float)
    cut = (logg_grid >=7) & (logg_grid <= 9) & ~(np.isnan(logg_grid))
    wd_teff_grid = wd_teff_grid[cut]
    if wd_teff_grid.size == 0:
        ## nothing to interpolate on; the fallback below would crash on .min()
        return np.nan

    excess = obs_nuv_flux(obs_mag) - nuv

    try:
        fvec = np.array([wd_reddened_nuv(t,m2,parallax,av,core,atm)/excess -1 for t in wd_teff_grid])
        finite = ~np.isnan(fvec)
        wd_teff_grid = wd_teff_grid[finite]
        fvec = fvec[finite]
        ## interp1d raises ValueError when 0 is outside the sampled range
        t = float(interp1d(fvec,wd_teff_grid,kind='linear')(0))
    except Exception:
        ## best-effort fallback, but no longer swallowing KeyboardInterrupt/SystemExit
        if wd_teff_grid.size == 0:
            return np.nan
        ## +1 if not enough observed excess, -1 if too much observed excess
        t = float(1 * np.sign((wd_reddened_nuv(wd_teff_grid.min(),m2,parallax,av,core,atm) - excess)))
    return t

def assess_uv_excess(sources,idx):
    """Print the companion temperatures implied by the GALEX NUV/FUV excess of one source.

    The row of ``sources`` whose 'idx' equals ``idx`` is used. The primary's
    parameter envelope is teff1 and r1 scaled by 0.9/1.1 and m1 -/+ m1_err. The
    companion core type follows from m2: 'He' below 0.45, 'CO' below 1.1,
    otherwise 'ONe'. For each GALEX band whose magnitude is not masked, the
    companion temperature is solved for both H and He atmospheres and printed.

    Parameters
    ----------
    sources : table
        Must provide 'idx', 'teff1', 'r1', 'm1', 'm1_err', 'parallax',
        '[Fe/H]', 'Av', 'm2', 'nuv_mag' and 'fuv_mag'.
    idx : value matched against ``sources['idx']``
        Raises IndexError if no row matches.

    Returns
    -------
    None
    """
    i = np.where(sources['idx'] == idx)[0][0]
    src = sources[i]
    t_lo,t,t_hi = src['teff1'] * 0.9, src['teff1'], src['teff1']*1.1
    m_lo,m,m_hi = src['m1'] - src['m1_err'], src['m1'], src['m1'] + src['m1_err']
    r_lo,r,r_hi = src['r1'] * 0.9, src['r1'], src['r1'] * 1.1
    parallax = src['parallax']
    meta = src['[Fe/H]']
    av = src['Av']
    m2 = src['m2']

    nuv_mag = src['nuv_mag']
    fuv_mag = src['fuv_mag']

    ## The minimum cooling temperatures (get_cooling_temp for H and He) used to be
    ## computed here but were never used; they always assumed a CO core and could
    ## abort the whole assessment for (age, m2) outside that grid, so they are gone.

    if m2 < 0.45:
        core = 'He'
    elif m2 < 1.1:
        core = 'CO'
    else:
        core = 'ONe'

    ## a masked magnitude means there is no measurement in that band: skip silently
    if not np.ma.is_masked(nuv_mag):
        nuv_lo,nuv,nuv_hi = get_primary_nuv_lims(t_lo,t,t_hi,m_lo,m,m_hi,r_lo,r,r_hi,parallax,meta,av)
        nuv_H_tarr= get_companion_RT_lims_nuv(nuv_mag,nuv_lo,nuv,nuv_hi,m2,parallax,av,core=core,atm='H')
        nuv_He_tarr = get_companion_RT_lims_nuv(nuv_mag,nuv_lo,nuv,nuv_hi,m2,parallax,av,core=core,atm='He')
        print(f'{core} WD {idx} NUV: {nuv_H_tarr:.0f} K (H), {nuv_He_tarr:.0f} K (He)')
    if not np.ma.is_masked(fuv_mag):
        fuv_lo,fuv,fuv_hi = get_primary_fuv_lims(t_lo,t,t_hi,m_lo,m,m_hi,r_lo,r,r_hi,parallax,meta,av)
        fuv_H_tarr = get_companion_RT_lims_fuv(fuv_mag,fuv_lo,fuv,fuv_hi,m2,parallax,av,core=core,atm='H')
        fuv_He_tarr = get_companion_RT_lims_fuv(fuv_mag,fuv_lo,fuv,fuv_hi,m2,parallax,av,core=core,atm='He')
        print(f'{core} WD {idx} FUV: {fuv_H_tarr:.0f} K (H), {fuv_He_tarr:.0f} K (He)')

# Binary SED fit

In [None]:
import emcee


def fit_source(sources,idx,teff2_arr,core,atm):
    """Set up priors for a binary (primary + white dwarf) SED fit of one source.

    NOTE(review): this function is unfinished. It only defines the prior and
    likelihood closures; no sampler is run, ``core`` and ``atm`` are not used
    yet, and nothing is returned.

    Parameters
    ----------
    sources : table
        Must provide 'idx', 'parallax', 'parallax_error', '[Fe/H]', 'e_[Fe/H]',
        'Av', 'e_Av', 'teff1', 'r1', 'm1', 'm1_err', 'm2' and 'm2_err'.
    idx : value matched against ``sources['idx']``
        Raises IndexError if no row matches.
    teff2_arr : sequence of three values
        ``[t_lo, t_guess, t_hi]`` for the companion temperature.
    core, atm : str
        Companion core and atmosphere type (currently only referenced in
        commented-out code).
    """
    ## teff2_arr: [t_lo,t_guess,t_hi]
    i = np.where(sources['idx'] == idx)[0][0]

    ## catalogue values: centres and widths of the Gaussian priors below
    parallax_guess = sources[i]['parallax']
    parallax_err = sources[i]['parallax_error']
    meta_guess = sources[i]['[Fe/H]']
    meta_err = sources[i]['e_[Fe/H]']
    av_guess = sources[i]['Av']
    av_err = sources[i]['e_Av']

    teff1_guess = sources[i]['teff1']
    r1_guess = sources[i]['r1']
    m1_guess = sources[i]['m1']
    m1_err = sources[i]['m1_err']
    m2_guess = sources[i]['m2']
    m2_err = sources[i]['m2_err']
    ## NOTE(review): teff1_guess, r1_guess and the three teff2 values are not used yet
    teff2_lo, teff2_guess, teff2_hi = teff2_arr


    ## Gaussian log-priors are written without their normalisation constant
    def log_prior_av(av): ## Gaussian prior on Av
        return -0.5 * (av - av_guess)**2 / av_err**2
    def log_prior_parallax(parallax): ## Gaussian prior on parallax
        return -0.5 * (parallax - parallax_guess)**2 / parallax_err**2
    def log_prior_meta(meta): ## Gaussian prior on [Fe/H]
        return -0.5 * (meta - meta_guess)**2 / meta_err**2
    def log_prior_m1(m1): ## Gaussian prior on m1
        return -0.5 * (m1 - m1_guess)**2 / m1_err**2
    def log_prior_m2(m2): ## Gaussian prior on m2
        return -0.5 * (m2 - m2_guess)**2 / m2_err**2 
    def log_prior_r1(r1): ## Uniform prior on r1
        if 0.1 < r1 < 5:
            return 0
        return -np.inf
    def log_prior_teff1(teff1): ## Uniform prior on teff1
        if 3500 < teff1 < 15000:
            return 0
        return -np.inf
    def log_prior_teff2(teff2): ## Uniform prior on teff2
        if 5000 < teff2 < 80000:
            return 0
        return -np.inf

    def log_prior(params):
        ## params order: teff1, teff2, r1, m1, m2, av, parallax, meta (8 values)
        teff1,teff2,r1,m1,m2,av,parallax,meta = params
        lp_teff1 = log_prior_teff1(teff1)
        lp_teff2 = log_prior_teff2(teff2)
        lp_r1 = log_prior_r1(r1)
        ## short-circuit when any uniform prior rejects the point
        if not np.isfinite(lp_teff1 + lp_teff2 + lp_r1):
            return -np.inf
        lp_m1 = log_prior_m1(m1)
        lp_m2 = log_prior_m2(m2)
        lp_av = log_prior_av(av)
        lp_parallax = log_prior_parallax(parallax)
        lp_meta = log_prior_meta(meta)
        return lp_teff1 + lp_teff2 + lp_r1 + lp_m1 + lp_m2 + lp_av + lp_parallax + lp_meta

    def log_likelihood(params):
        ## NOTE(review): incomplete -- no model is evaluated and nothing is returned.
        ## NOTE(review): the photometry is re-fetched on every call, and the row is passed
        ## bare here while plot_kurucz_vs_obs wraps it in Table(...) -- confirm which is right.
        obs_tbl = get_photometry_single_source(sources[i])
        ## NOTE(review): unpacks 5 values in a different order than the 8-value vector
        ## expected by log_prior; the two must agree before this can be sampled.
        m1,m2,r1,teff1,teff2 = params
        # ms = model_sed_ms(teff1,m1,m1_err)
        # wd = get_wd_sed(teff2,m2,core,atm)
        


# Misc: Swift SNR, LaTeX tables

In [60]:
# zp = 7.5e-16
# print('idx cts/s texp')
# for i in range(len(sources)):
#     av = sources[i]['Av']
#     anuv = extinction.get_Galex_extinction(av,0,sources[i]['teff2_min'])[0]
#     m2 = sources[i]['m2']
#     age = sources[i]['age']
#     nuv = m_agecool_to_mag(m2,age*1e-3)
#     parallax = sources[i]['parallax']
#     nuv_apparent = nuv + 5 * np.log10(1000/parallax) - 5 + anuv
#     c = obs_nuv_flux(nuv_apparent,av)/zp
#     texp = 20**2 / c
#     if texp < 1000 and np.ma.is_masked(sources[i]['nuv_mag']):
#         print(sources[i]['idx'],c.round(2),int(texp))
    

idx cts/s texp
53 0.5 805
101 0.77 522
105 0.69 582
191 0.7 571
192 1.04 382
282 0.88 456


In [41]:
# sources = sources[sources['idx'] == 99]

# sources['phot_g_mean_mag'] = sources['phot_g_mean_mag'].round(1)
# sources['ra'] = sources['ra'].round(5)
# sources['dec'] = sources['dec'].round(5)
# sources['distance'] = (1000/sources['parallax']).round(0)
# sources['distance_err'] = (1000/sources['parallax']**2 * sources['parallax_error']).round(0)
# sources['D'] = [f'{d}±{d_err}' for d,d_err in zip(sources['distance'],sources['distance_err'])]
# sources['nuv_mag'] = sources['nuv_mag'].round(1)
# sources['m1'] = sources['m1'].round(2)
# sources['m1_err'] = sources['m1_err'].round(2)
# sources['m2'] = sources['m2'].round(2)
# sources['m2_err'] = sources['m2_err'].round(2)
# sources['r1'] = sources['r1'].round(2)
# sources['r1_err'] = sources['r1_err'].round(2)
# sources['teff1'] = np.array(sources['teff1'].round(-1),dtype = int)
# sources['teff1_err'] = np.array(sources['teff1_err'].round(-1),dtype = int)
# sources['teff2'] = np.array(sources['teff2'].round(-1),dtype = int)
# sources['age'] = (10**sources['log_age_50'] * 1e-6).round(0)
# sources['M1'] = [f'{m:.2f}±{m_err:.2f}' for m,m_err in zip(sources['m1'],sources['m1_err'])]
# sources['R1'] = [f'{r:.2f}±{r_err:.2f}' for r,r_err in zip(sources['r1'],sources['r1_err'])]
# sources['M2'] = [f'{m:.2f}±{m_err:.2f}' for m,m_err in zip(sources['m2'],sources['m2_err'])]
# sources['T1'] = [f'{t}±{t_err}' for t,t_err in zip(sources['teff1'],sources['teff1_err'])]
# sources['[Fe/H]'] = sources['[Fe/H]'].round(2)
# sources['Av'] = sources['Av'].round(2)
# sources['age'] = sources['age'].round(0)   
# print(sources['T1','R1','M2','teff2','[Fe/H]','Av','age'].to_pandas().to_latex(index=False))

\begin{tabular}{lllrrrr}
\toprule
T1 & R1 & M2 & teff2 & mh_for_mass_interp & av_for_mass_interp & age \\
\midrule
5400±110 & 1.15±0.06 & 0.51±0.03 & 16990 & -0.050000 & 0.010000 & 162.000000 \\
\bottomrule
\end{tabular}



# Plotting

In [None]:
## Load the candidate table and plot the observed photometry of one row against its model.
## NOTE(review): this rebinds the notebook-global `sources`; any other cell using `sources`
## sees this table after the cell has run.
sources = Table.read('../table_C.fits',format='fits')

## Select the source by row position; the commented lines select it by its 'idx' value instead.
# idx = 1
# i = np.where(sources['idx'] == idx)[0][0]
i = 5  ## NOTE(review): hard-coded row index
idx = sources[i]['idx']

## The values below are only consumed by the commented-out plotting calls further down.
m1 = sources[i]['m1']
m1_err = sources[i]['m1_err']
m2 = sources[i]['m2']
## for m2 above 1.25 the value is lowered by its error -- presumably to stay inside
## the white-dwarf model grid; TODO confirm
if m2>1.25:
    m2 = sources[i]['m2'] - sources[i]['m2_err']
m2_err = sources[i]['m2_err']

parallax = sources[i]['parallax']
parallax_err = sources[i]['parallax_error']

teff1 = sources[i]['teff1']
# teff2 = sources[i]['teff2']

meta = sources[i]['[Fe/H]']

# wl, fl = plot_binary_wd(teff1,m1,m1_err,teff2,m2,m2_err,parallax,parallax_err,meta,gratings = [1105,'galex'])
plot_kurucz_vs_obs(sources,idx)
# plot_sed_ms_wd(teff1,m1,m1_err,teff2,m2,m2_err,parallax,parallax_err,meta)
# plot_binary_ms(8000,m1,m1_err,6000,m2,m2_err,parallax,parallax_err)
# plot_triple_ms(9000,m1,m1_err,3500,m2,m2_err,parallax,parallax_err)


# tbl = Table({'wavelength':wl,'flux':fl})
# cond = (tbl['wavelength'] > 925) & (tbl['wavelength'] < 80000)
# tbl = tbl[cond]
# tbl.write('../SBA28.dat',format='ascii',overwrite=True)