# Query Gaia for WDS entries — parallelized with multiprocessing across multiple cores
#### Summer 2022 -> revised in Spring 2023
#### Daphne Zakarian

In [None]:
# conda install astroquery
!pip install astroquery

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting astroquery
  Downloading astroquery-0.4.6-py3-none-any.whl (4.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m59.9 MB/s[0m eta [36m0:00:00[0m
Collecting pyvo>=1.1
  Downloading pyvo-1.4.1-py3-none-any.whl (887 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m887.9/887.9 kB[0m [31m51.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting keyring>=4.0
  Downloading keyring-23.13.1-py3-none-any.whl (37 kB)
Collecting jeepney>=0.4.2
  Downloading jeepney-0.8.0-py3-none-any.whl (48 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.4/48.4 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jaraco.classes
  Downloading jaraco.classes-3.2.3-py3-none-any.whl (6.0 kB)
Collecting SecretStorage>=3.2
  Downloading SecretStorage-3.3.3-py3-none-any.whl (15 kB)
Installing collected packages: jeepney, jara

In [None]:

from astropy.io import ascii
from astropy.table import vstack, Table, unique
from astropy.coordinates import SkyCoord 
import astropy.units as u
from astropy import table, log
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord, Distance, Angle
from astropy.time import Time
from astropy.io import ascii
from astroquery.gaia import Gaia
from astroquery.utils.tap.model import job
from itertools import combinations
import multiprocessing
from multiprocessing import Queue, Pool, freeze_support, Process
import os
from IPython.display import display
from multiprocessing import set_start_method

In [None]:
# Mount Google Drive so the WDS table under /content/drive can be read in Colab.
# Outside Colab (e.g. on the local lab machine) google.colab does not exist, so skip
# the mount instead of failing and let the rest of the notebook run.
try:
    from google.colab import drive
except ImportError:
    print('google.colab not available; skipping Google Drive mount')
else:
    drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Read the prepped WDS table. Use the Google Drive copy when running in Colab,
# otherwise the local copy on the lab computer. If neither exists, the Colab path
# is used so the error message points at the primary location.
wds_table_candidates = [
    '/content/drive/MyDrive/NOFS copy/wdstab6-27.ecsv',
    'C:/Users/sc36/Documents/DaphneUSNO/NOFS copy-20230218T215456Z-001/NOFS copy/wdstab6-27.ecsv',
]
path = next((candidate for candidate in wds_table_candidates if os.path.exists(candidate)),
            wds_table_candidates[0])
wdstab = Table.read(path, header_start=0, data_start=1)

## query_gaia(coordinate, radius)

In [None]:
def query_gaia(coordinate, radius):
    """Cone-search Gaia DR3 around one sky position and return the matching sources.

    Only sources passing the quality cuts in the WHERE clause are returned:
    parallax > 1 mas, parallax_over_error > 5 and parallax_error < 2 mas.

    Parameters
    ----------
    coordinate : SkyCoord
        Centre of the cone; its ``ra`` and ``dec`` are used in degrees.
    radius : astropy Quantity
        Cone radius with angular units (e.g. ``5*u.arcsec``); converted to degrees.

    Returns
    -------
    astropy.table.Table
        One row per matching Gaia source with the columns in ``column_names``
        (possibly empty).
    """
    # these column names list the info to pull from Gaia
    # if you change this, make sure to change the wds_in_gaia_query() function
    # to update that info in the tables themselves!!
    column_names = ['source_id', 'ref_epoch', 'ra', 'ra_error', 'dec',
        'dec_error', 'parallax', 'parallax_error', 'parallax_over_error','pmra',
        'pmra_error', 'pmdec', 'pmdec_error',
        'radial_velocity', 'radial_velocity_error',
        'astrometric_params_solved', 'visibility_periods_used',
        'astrometric_sigma5d_max','ruwe',
        'phot_g_mean_mag', 'phot_g_mean_flux_over_error',
        'phot_bp_mean_mag', 'phot_bp_mean_flux_over_error',
        'phot_rp_mean_mag', 'phot_rp_mean_flux_over_error',
        'bp_rp','phot_bp_rp_excess_factor']

    # ADQL needs the columns as one comma-separated string, not a Python list
    columns = ', '.join(column_names)

    # get the degree value for coordinate and radius
    ra = coordinate.ra.deg
    dec = coordinate.dec.deg
    radius_deg = float(radius.to_value(u.deg))

    # Cone search written in the usual orientation: POINT built from the table's
    # ra/dec columns, CIRCLE centred on the target. Angular separation is symmetric,
    # so this selects the same sources as the swapped form. NOTE(review): this is the
    # form used in the Gaia archive ADQL examples; presumably it also lets the server
    # use its positional index — confirm against the archive docs.
    query_base = """
    SELECT {columns}
    FROM gaiadr3.gaia_source
    WHERE parallax > 1
    AND parallax_over_error > 5
    AND parallax_error < 2
    AND 1 = CONTAINS(
    POINT(ra, dec),
    CIRCLE({ra}, {dec}, {rad}))
    """

    # format the query with our specific info
    query = query_base.format(columns=columns, ra=ra, dec=dec, rad=radius_deg)

    # make the query to gaia and return the results as an astropy table
    gaia_job = Gaia.launch_job_async(query)
    return gaia_job.get_results()



## Test queries for individual WDS rows

In [None]:
# # Read in WDS (from Vayu's Lab comp)
#path = 'C:/Users/sc36/Documents/DaphneUSNO/NOFS copy-20230218T215456Z-001/NOFS copy/wdstab6-27.ecsv'
#wdstab = Table.read(path, header_start=0, data_start=1) 


# Spot-check query_gaia on one WDS row: search a cone around the primary and around
# the secondary, then stack both result tables for inspection.
rownum = 754
# search radius as an astropy Quantity (query_gaia converts it to degrees)
radius = 5*u.arcsec

# read in the coordinates of the primary and secondary in WDS for the designated row number
wds_row = wdstab[rownum]
ra1, dec1 = wds_row['RApri-prepped'], wds_row['DECpri-prepped']
ra2, dec2 = wds_row['RAsec-prepped'], wds_row['DECsec-prepped']

coord1 = SkyCoord(ra=ra1, dec=dec1, unit='deg')
coord2 = SkyCoord(ra=ra2, dec=dec2, unit='deg')
myquery1 = query_gaia(coordinate=coord1, radius=radius)
myquery2 = query_gaia(coordinate=coord2, radius=radius)

vstack([myquery1, myquery2])


INFO:astroquery:Query finished.


INFO: Query finished. [astroquery.utils.tap.core]


INFO:astroquery:Query finished.


INFO: Query finished. [astroquery.utils.tap.core]


source_id,ref_epoch,ra,ra_error,dec,dec_error,parallax,parallax_error,parallax_over_error,pmra,pmra_error,pmdec,pmdec_error,radial_velocity,radial_velocity_error,astrometric_params_solved,visibility_periods_used,astrometric_sigma5d_max,ruwe,phot_g_mean_mag,phot_g_mean_flux_over_error,phot_bp_mean_mag,phot_bp_mean_flux_over_error,phot_rp_mean_mag,phot_rp_mean_flux_over_error,bp_rp,phot_bp_rp_excess_factor
Unnamed: 0_level_1,yr,deg,mas,deg,mas,mas,mas,Unnamed: 8_level_1,mas / yr,mas / yr,mas / yr,mas / yr,km / s,km / s,Unnamed: 15_level_1,Unnamed: 16_level_1,mas,Unnamed: 18_level_1,mag,Unnamed: 20_level_1,mag,Unnamed: 22_level_1,mag,Unnamed: 24_level_1,mag,Unnamed: 26_level_1
int64,float64,float64,float32,float64,float32,float64,float32,float32,float64,float32,float64,float32,float32,float32,int16,int16,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32


## wds_in_gaia_query(core_num, total_cores) --- query WDS entries in Gaia and save results in a table

In [None]:
def _core_row_range(total_rows, core_num, total_cores):
    """Return the half-open WDS row range ``(start, end)`` handled by one core.

    The table is cut into ``total_cores`` chunks of ``total_rows // total_cores`` rows.
    The last chunk also takes the ``total_rows % total_cores`` leftover rows, so every
    row is covered exactly once and ``end`` never exceeds ``total_rows``.
    """
    if total_cores < 1 or not 0 <= core_num < total_cores:
        raise ValueError('core_num must satisfy 0 <= core_num < total_cores, got '
                         '{0} of {1}'.format(core_num, total_cores))
    rows_per_core = total_rows // total_cores
    start = core_num * rows_per_core
    if core_num == total_cores - 1:
        return start, total_rows
    return start, start + rows_per_core


def _empty_pair_table(num_column_names):
    """Build the empty output table that holds one pair of Gaia sources per row.

    Columns are ``wds_identifier``, ``source_id_a``, ``source_id_b`` (30-character
    strings), then ``<name>_a`` and ``<name>_b`` (float64) for each entry of
    ``num_column_names``.
    """
    names = ['wds_identifier', 'source_id_a', 'source_id_b']
    dtypes = ['U30', 'U30', 'U30']
    for column in num_column_names:
        names.extend(['{0}_a'.format(column), '{0}_b'.format(column)])
        dtypes.extend(['f8', 'f8'])
    return Table(names=names, dtype=dtypes)


def _record_wds_row(error_table, wds_identifier, rownum):
    """Append one ``(wds_identifier, wds_rownum)`` entry to an error-tracking table."""
    error_table.add_row()
    error_table['wds_identifier'][-1] = wds_identifier
    error_table['wds_rownum'][-1] = rownum


def _append_pairs(query_results_table, gaiaresults, wds_identifier):
    """Add one row to ``query_results_table`` per unordered pair of rows in ``gaiaresults``.

    Each ``<name>_a`` column is filled from the first row of the pair and each
    ``<name>_b`` column from the second row, using the Gaia column ``<name>``.
    """
    for index_a, index_b in combinations(range(len(gaiaresults)), 2):
        row_a, row_b = gaiaresults[index_a], gaiaresults[index_b]
        query_results_table.add_row()
        query_results_table['wds_identifier'][-1] = wds_identifier
        for column in query_results_table.colnames:
            # strip the '_a' / '_b' suffix to get the Gaia column name
            if column.endswith('_a'):
                query_results_table[column][-1] = row_a[column[:-2]]
            elif column.endswith('_b'):
                query_results_table[column][-1] = row_b[column[:-2]]


def wds_in_gaia_query(core_num, total_cores,
                      path='C:/Users/sc36/Documents/DaphneUSNO/NOFS copy-20230218T215456Z-001/NOFS copy/wdstab6-27.ecsv',
                      save_path='C:/Users/sc36/Documents/DaphneUSNO/NOFS copy-20230218T215456Z-001/NOFS copy/QueryResults'):
    """Query Gaia DR3 for one chunk of the WDS table and save every candidate source pair.

    The WDS table at ``path`` is split into ``total_cores`` chunks and only chunk
    ``core_num`` (0-based) is processed. For each WDS identifier, every row sharing it
    is queried with ``query_gaia`` (5 arcsec cones around the primary and the
    secondary). The results are de-duplicated by ``source_id``, and every pair of
    distinct Gaia sources becomes one row of the results table.

    Bookkeeping tables:
      * index_error_queries: rows whose two cone searches both returned nothing, and
        identifiers that yielded fewer than two distinct Gaia sources.
      * unknown_error_queries: identifiers whose queries or processing raised.

    Each table is written as ECSV and CSV to ``save_path`` with a ``_c{core_num}`` suffix.

    Rows with the same identifier are assumed to be adjacent in the table
    (NOTE(review): confirm the WDS table is sorted by identifier). An identifier group
    that straddles a chunk boundary is processed separately by the two cores.
    """
    wdstab = Table.read(path, header_start=0, data_start=1)

    # start row is included in the query range, end row is not
    wds_start_row, wds_end_row = _core_row_range(len(wdstab), core_num, total_cores)

    # these are the Gaia columns with a numeric data type;
    # the source ids are stored separately as strings
    num_column_names = ['ref_epoch', 'ra', 'ra_error', 'dec',
                    'dec_error', 'parallax', 'parallax_error', 'parallax_over_error','pmra',
                    'pmra_error', 'pmdec', 'pmdec_error',
                    'radial_velocity', 'radial_velocity_error',
                    'astrometric_params_solved', 'visibility_periods_used',
                    'astrometric_sigma5d_max','ruwe',
                    'phot_g_mean_mag', 'phot_g_mean_flux_over_error',
                    'phot_bp_mean_mag', 'phot_bp_mean_flux_over_error',
                    'phot_rp_mean_mag', 'phot_rp_mean_flux_over_error',
                    'bp_rp','phot_bp_rp_excess_factor']

    """ BUILD OUTPUT TABLES """
    query_results_table = _empty_pair_table(num_column_names)
    index_error_queries = Table(names = ('wds_identifier', 'wds_rownum'), dtype = ('a30', 'f8'))
    unknown_error_queries = Table(names = ('wds_identifier', 'wds_rownum'), dtype = ('a30', 'f8'))

    wds_ids = wdstab['WDS Identifier']
    search_radius = 5*u.arcsec

    rownum = wds_start_row
    while rownum < wds_end_row:
        wds_identifier = wds_ids[rownum]

        # find the end of the run of adjacent rows sharing this identifier
        group_end = rownum + 1
        while group_end < wds_end_row and wds_ids[group_end] == wds_identifier:
            group_end += 1

        try:
            # results are collected per identifier, so nothing leaks in from the previous one
            found_tables = []
            for shared_id_rownum in range(rownum, group_end):
                print('\n core # ', core_num, 'of ', total_cores, 'cores   --- row number: ', shared_id_rownum)

                """ make the 2 queries for given WDS row """
                wds_row = wdstab[shared_id_rownum]
                coord1 = SkyCoord(ra=wds_row['RApri-prepped'], dec=wds_row['DECpri-prepped'], unit='deg')
                coord2 = SkyCoord(ra=wds_row['RAsec-prepped'], dec=wds_row['DECsec-prepped'], unit='deg')
                myquery1 = query_gaia(coordinate=coord1, radius=search_radius)
                myquery2 = query_gaia(coordinate=coord2, radius=search_radius)

                if len(myquery1) + len(myquery2) == 0:
                    _record_wds_row(index_error_queries, wds_identifier, shared_id_rownum)
                else:
                    found_tables.extend([myquery1, myquery2])

            # every row of this identifier returned nothing; already recorded above
            if not found_tables:
                continue

            """ REMOVE DUPLICATES FROM GAIA RESULTS TABLE """
            # key on source_id only: radial_velocity etc. are often masked and can't act as keys
            gaiaresults = unique(vstack(found_tables), keys='source_id', keep='first', silent=True)

            # fewer than two distinct Gaia sources -> no pair to report
            if len(gaiaresults) <= 1:
                _record_wds_row(index_error_queries, wds_identifier, rownum)
            else:
                """ CROSS CHECK EACH ENTRY WITH EACH OTHER """
                _append_pairs(query_results_table, gaiaresults, wds_identifier)

        except Exception as exc:
            # record any other failure (network, table, ...) and move on to the next identifier
            print('core', core_num, 'unknown error for', wds_identifier, ':', repr(exc))
            _record_wds_row(unknown_error_queries, wds_identifier, rownum)

        finally:
            rownum = group_end

    os.makedirs(save_path, exist_ok=True)
    outputs = [('query_results_table', query_results_table),
               ('index_error_queries', index_error_queries),
               ('unknown_error_queries', unknown_error_queries)]
    for base_name, output_table in outputs:
        for fmt in ('ecsv', 'csv'):
            out_file = '{path}/{name}_c{core}.{fmt}'.format(path=save_path, name=base_name, core=core_num, fmt=fmt)
            ascii.write(output_table, out_file, format=fmt, overwrite=True)
    

In [None]:


# wds_in_gaia_query(0,30000)
# wds_in_gaia_query(1,30000)
# wds_in_gaia_query(2,30000)
# wds_in_gaia_query(3,30000)




### Dividing up the WDS table for multiprocessing
##### The same split is done inside `wds_in_gaia_query`; it is reproduced here only so the chunk boundaries can be checked.

In [None]:
# Prepare for multiprocessing: reproduce the row split used inside wds_in_gaia_query
# so the chunk boundaries can be checked by eye.
total_cores = 2

# total number of queries will be the number of wds entries that we look at
total_num_queries = len(wdstab)

# every core gets queries_per_core rows; the last core also takes the leftover rows
queries_per_core = total_num_queries // total_cores
leftover_rows = total_num_queries % total_cores

# half-open [start, end) row ranges, one per core
start_row_list = [core * queries_per_core for core in range(total_cores)]
end_row_list = start_row_list[1:] + [total_num_queries]

# sanity checks: the chunks tile the table exactly, with the remainder on the last core
assert end_row_list[-1] == total_num_queries
assert end_row_list[-1] - start_row_list[-1] == queries_per_core + leftover_rows

list(zip(start_row_list, end_row_list))




## Initiate Gaia Query with multiprocessing

In [None]:

def initiate_gaia_query(total_cores):
    """Run wds_in_gaia_query on every chunk of the WDS table in parallel.

    Starts one process per chunk (chunks 0 .. total_cores-1 of a ``total_cores``-way
    split), then waits for all of them. Each process writes its own output files, so
    nothing is returned.
    """
    processes = []
    for core_num in range(total_cores):
        p = multiprocessing.Process(target = wds_in_gaia_query, args = (core_num, total_cores))
        p.start()
        processes.append(p)

    # join only after every process has started, so the chunks run concurrently
    for p in processes:
        p.join()


## Test query multiprocessing

In [None]:
def initiate_test_gaia_query(num_of_processes=4, divide_wds=3000):
    """Smoke-test the parallel query on a small slice of the WDS table.

    Splits the table into ``divide_wds`` chunks but only runs the first
    ``num_of_processes`` of them, one process each, then waits for all of them.

    Parameters
    ----------
    num_of_processes : int
        Number of chunks (and processes) to run, starting at chunk 0.
    divide_wds : int
        Number of chunks the WDS table is split into.
    """
    processes = []
    for core_num in range(num_of_processes):
        print('process initiated: core', core_num)
        p = multiprocessing.Process(target = wds_in_gaia_query, args = (core_num, divide_wds))
        p.start()
        processes.append(p)

    # join only after every process has started, so the chunks run concurrently
    for p in processes:
        p.join()


In [None]:
# Smoke test: run the first few chunks of the WDS split in parallel, with the process
# count and split size set in initiate_test_gaia_query. Each process reads and writes
# the paths configured in wds_in_gaia_query.
initiate_test_gaia_query()

process initiated: core 0
process initiated: core 1
process initiated: core 2
process initiated: core 3


In [None]:


# if __name__ == '__main__':
#     pool = Pool()
#     divide_wds = 3000                         # Create a multiprocessing Pool
#     for core_num in range(4):
#         pool.map(wds_in_gaia_query,core_num, divide_wds)  # process data_inputs iterable with pool 
        
    

In [None]:
if __name__ == '__main__':
    # Run chunks 0-2 of a 50000-way split of the WDS table concurrently.
    # Start every process first and only then join, otherwise each chunk
    # would wait for the previous one and the run would be sequential.
    list_of_cores = [0,1,2]
    processes = []
    for i in list_of_cores:
        p = Process(target=wds_in_gaia_query, args = (i, 50000,))
        p.start()
        processes.append(p)

    print('Waiting for simple func to end')
    for p in processes:
        p.join()


# ### stack output files 

In [None]:
# file_dictionary = {}

# total_cores = 4

# for core_num in range(total_cores):

#             file_dictionary['query_results_table_c{0}'.format(core_num)] = 0
#             file_dictionary['index_error_queries_c{0}'.format(core_num)] = 0
            
            

# directory = 'C:/Users/sc36/Documents/DaphneUSNO/NOFS copy-20230218T215456Z-001/NOFS copy/QueryResults'

# for file in file_dictionary:
    
#     file_dictionary[file] = Table.read('{0}/{1}.ecsv'.format(directory, file), header_start=0, data_start=1)
    

    
    

# # vertically stack all 20 sections of each table


# query_results_table_list = []
# index_error_queries_list = []


# for file in file_dictionary:
#     if file.startswith('query_results_table_c'):
#         query_results_table_list.append(file_dictionary[file])
#     elif file.startswith('index_error_queries_c'):
#         index_error_queries_list.append(file_dictionary[file])



# stack_query_results_table = vstack(query_results_table_list)
# stack_index_error_queries = vstack(index_error_queries_list)


# ascii.write(stack_query_results_table, '{0}/stack_query_results_table.ecsv'.format(directory), format='ecsv')
# ascii.write(stack_query_results_table, '{0}/stack_query_results_table.csv'.format(directory), format='csv')


# ascii.write(stack_index_error_queries, '{0}/stack_index_error_queries.ecsv'.format(directory), format='ecsv')
# ascii.write(stack_index_error_queries, '{0}/stack_index_error_queries.csv'.format(directory), format='csv')


# qrt ='{0}/stack_query_results_table.ecsv'.format(directory) 
# ie = '{0}/stack_index_error_queries.ecsv'.format(directory)
# stack_query_results_table = Table.read(qrt, header_start=0, data_start=1)
# stack_index_error_queries = Table.read(ie, header_start=0, data_start=1)





# # In[13]: