In [4]:
from astroquery.gaia import Gaia
from typing import List
import os

## Resolve source

In [5]:
def resolve(id: str = None, coords: List = None):
    """Look up a Gaia DR3 source and print basic quality warnings for it.

    ``id`` takes precedence over ``coords`` when both are given.

    Parameters
    ----------
    id : str or int, optional
        Gaia DR3 ``source_id``. Must be parseable with ``int()``.
    coords : list, optional
        ``[ra, dec]``; each value must be parseable with ``float()``.
        NOTE(review): the query matches ``ra``/``dec`` by exact equality, so the
        values must reproduce the catalogue values exactly; a cone search
        (ADQL ``CONTAINS``/``CIRCLE``) would be more forgiving.

    Returns
    -------
    pandas.DataFrame
        Matching rows (possibly empty) with the selected ``gaia_source_lite``
        and ``astrophysical_parameters`` columns.

    Raises
    ------
    ValueError
        If neither ``id`` nor ``coords`` is given, if ``id`` is not an integer,
        or if ``coords`` is not exactly two numeric values.
    """
    # Columns and join shared by both lookup modes.
    base_query = (
        "SELECT gaia3.source_id, gaia3.ra, gaia3.dec, gaia3.parallax, "
        "gaia3.parallax_over_error, gaia3.ruwe, gaia3.has_xp_sampled, "
        "gaia3ap.classprob_dsc_combmod_star, gaia3ap.classprob_dsc_specmod_star "
        "FROM gaiadr3.gaia_source_lite AS gaia3 "
        "JOIN gaiadr3.astrophysical_parameters AS gaia3ap "
        "ON gaia3.source_id = gaia3ap.source_id "
    )

    if id:
        # Use the parameter (not a global) and parse it, so only a plain
        # integer is ever interpolated into the ADQL text.
        where = f"WHERE gaia3.source_id = {int(id)}"
    elif coords:
        # float() rejects non-numeric input; unpacking enforces exactly two values.
        ra, dec = (float(value) for value in coords)
        where = f"WHERE gaia3.ra = {ra} AND gaia3.dec = {dec}"
    else:
        raise ValueError("Either 'id' or 'coords' must be provided!")

    job = Gaia.launch_job_async(base_query + where)
    # Convert once and reuse the same frame for checking and returning.
    results = job.get_results().to_pandas()

    if results.empty:
        print("No sources found!")
    else:
        _check_quality(results)
    return results

def _check_quality(results):
    if (results['ruwe'] > 1.4).any() or results['parallax'].isnull().any() or (results['parallax_over_error'] <= 3).any():
        print("The source has poor parameters, it might not be properly resolved.")

    if (results['classprob_dsc_combmod_star']<0.5).any() or (results['classprob_dsc_specmod_star']<0.5).any():
        print("The source is most likely not a star.")
    
    if (results['has_xp_sampled'] != True).any():
        print("The source has no BP-RP spectrum data in Gaia Data Release 3!")

# Example inputs: a Gaia DR3 source ID and its catalogue position.
# To look the source up by position instead, call resolve(coords=[ra, dec]).
source_id = '4111834567779557376'
ra, dec = '256.5229102004341', '-26.580565130784702'

results = resolve(id=source_id)

INFO: Query finished. [astroquery.utils.tap.core]


In [6]:
# Show the resolved source(s) with the astrometric and classification columns used by the quality checks.
results

Unnamed: 0,SOURCE_ID,ra,dec,parallax,parallax_over_error,ruwe,has_xp_sampled,classprob_dsc_combmod_star,classprob_dsc_specmod_star
0,4111834567779557376,256.52291,-26.580565,1.153767,47.893497,0.836915,True,0.99992,0.992043


## Download data

In [13]:
def pull_data(results, out_dir='./temp'):
    """Download Gaia DR3 sampled BP/RP (XP_SAMPLED) spectra and save them as CSV.

    Every datalink product whose key contains ``XP_SAMPLED`` is written to
    ``out_dir`` (created if needed) with a ``.csv`` extension, so the
    inference cell's ``*.csv`` glob always finds it.

    Parameters
    ----------
    results : pandas.DataFrame
        Output of ``resolve``; only the ``SOURCE_ID`` column is used.
    out_dir : str, optional
        Destination directory. Defaults to ``'./temp'``, which the inference
        cell reads and then deletes.

    Returns
    -------
    None
    """
    retrieval_type = 'XP_SAMPLED'
    data_structure = 'INDIVIDUAL'
    data_release = 'Gaia DR3'

    # Nothing to request; avoid calling the archive with an empty ID list.
    if results.empty:
        print("No sources to download.")
        return

    datalink = Gaia.load_data(ids=results['SOURCE_ID'], data_release=data_release,
                              retrieval_type=retrieval_type, format='csv',
                              data_structure=data_structure)

    for dl_key, products in datalink.items():
        if retrieval_type not in dl_key:
            continue
        product = products[0]

        # Drop whatever extension the key carries (e.g. '.xml' or '.csv') and
        # always write '.csv', since the table is written in CSV format.
        stem, _ = os.path.splitext(dl_key)
        file_name = stem.replace(' ', '_').replace('-', '_') + '.csv'

        print(f'Writing table as: {file_name}')
        os.makedirs(out_dir, exist_ok=True)
        product.write(os.path.join(out_dir, file_name), format='csv', overwrite=True)

# Download the XP spectra of the resolved source(s) into ./temp for the inference cell below.
pull_data(results)

## Run inference and delete the downloaded data

In [45]:
import pandas as pd
import glob
import torch
import shutil

data_dir = './temp'
file = '*.csv'
model_path = '../models/cnn_ensemble.pth'


def inference(model, X):
    """Classify spectra with a binary TorchScript model.

    Parameters
    ----------
    model : torch.nn.Module or torch.jit.ScriptModule
        Presumably returns logits (a sigmoid is applied here) — TODO confirm
        against the training code. It is run in ``eval`` mode without gradients.
    X : torch.Tensor
        Flux values with a leading batch dimension, e.g. ``(1, n)``. A channel
        dimension is inserted at axis 1 before the forward pass.

    Returns
    -------
    numpy.ndarray
        Sigmoid outputs rounded to 0.0 / 1.0, as float64.
    """
    model.eval()
    with torch.no_grad():
        output = model(X.unsqueeze(1))
        prob = torch.sigmoid(output)
        return torch.round(prob).numpy().astype(float)


# Sorted so the processing order is deterministic across runs and filesystems.
csv_files = sorted(glob.glob(os.path.join(data_dir, file)))
if csv_files:
    try:
        # Load the model once and classify every downloaded spectrum, not just
        # the first one (all files are deleted afterwards).
        model = torch.jit.load(model_path, map_location=torch.device('cpu'))
        for csv_file in csv_files:
            data = pd.read_csv(csv_file)
            X = data['flux'].to_numpy()
            prediction = inference(model, torch.from_numpy(X).float().unsqueeze(0))
            print(prediction)
    finally:
        # Always remove the downloaded data, even if inference fails, so a stale
        # spectrum cannot be picked up by the next run's glob.
        try:
            shutil.rmtree(data_dir)
            print(f"Deleted {data_dir} directory and its contents.")
        except Exception as e:
            print(f"Error deleting {data_dir} directory: {e}")
else:
    print("No CSV files found in the temp directory.")

[[0.]]
Deleted ./temp directory and its contents.
