In [1]:
from astroquery.gaia import Gaia
import matplotlib.pyplot as plt
import math

In [2]:
def extract_dl_ind(datalink_dict, key, figsize = [15,5], fontsize = 12, linewidth = 2, show_legend = True, show_grid = True):
    """Extract one DataLink product, plot it, and return it as an Astropy Table.

    Parameters
    ----------
    datalink_dict : dict
        Maps a product key to a sequence of DataLink resources. Only the first
        element of ``datalink_dict[key]`` is used, via its ``to_table()`` method.
    key : hashable
        Product to extract from ``datalink_dict``.
    figsize, fontsize, show_grid
        Plot options forwarded to ``plot_e_phot`` / ``plot_sampled_spec``.
    linewidth
        Line width, used only when a spectrum is plotted.
    show_legend
        Used only for epoch photometry; spectra are always drawn without a
        legend.

    Returns
    -------
    The table returned by ``to_table()``.

    Notes
    -----
    A table with a ``time`` column is plotted as epoch photometry. A table with
    a ``wavelength`` column is plotted as a sampled spectrum whose title is
    picked from its row count (343 rows -> 'XP Sampled', 2401 rows -> 'RVS').
    Any other row count gets the generic title 'Sampled spectrum' rather than
    leaving the title undefined.
    """
    # Row count -> plot title for the known sampled-spectrum products.
    spectrum_titles = {343: 'XP Sampled', 2401: 'RVS'}

    dl_out = datalink_dict[key][0].to_table()
    columns = dl_out.keys()
    if 'time' in columns:
        plot_e_phot(dl_out, colours = ['green', 'red', 'blue'], title = 'Epoch photometry', fontsize = fontsize, show_legend = show_legend, show_grid = show_grid, figsize = figsize)
    if 'wavelength' in columns:
        title = spectrum_titles.get(len(dl_out), 'Sampled spectrum')
        # legend='' gives the line no label, so the legend is switched off.
        plot_sampled_spec(dl_out, color = 'blue', title = title, fontsize = fontsize, show_legend = False, show_grid = show_grid, linewidth = linewidth, legend = '', figsize = figsize)
    return dl_out


def plot_e_phot(inp_table, colours  = ['green', 'red', 'blue'], title = 'Epoch photometry', fontsize = 12, show_legend = True, show_grid = True, figsize = [15,5]):
    """Plot Gaia epoch photometry (magnitude vs. time) for the G, RP and BP bands.

    A new figure is created and shown.

    Parameters
    ----------
    inp_table : astropy.table.Table
        MUST be an Astropy-table object with 'time', 'mag' and 'band' columns.
        The ``.unit`` of 'time' and 'mag' is used in the axis labels.
    colours : sequence of str
        Marker colours, consumed in band order G, RP, BP. At least three are
        needed.
    title : str
        Plot title.
    fontsize : float
        Font size for title, labels and ticks.
    show_legend, show_grid : bool
        Forwarded to ``make_canvas``.
    figsize : sequence of two numbers
        Figure size passed to ``plt.figure``.
    """
    plt.figure(figsize=figsize)
    xlabel = f'JD date [{inp_table["time"].unit}]'
    ylabel = f'magnitude [{inp_table["mag"].unit}]'
    colour_iter = iter(colours)

    # Magnitudes are smaller for brighter sources, so flip the y-axis.
    plt.gca().invert_yaxis()
    for band in ('G', 'RP', 'BP'):
        band_rows = inp_table[inp_table['band'] == band]
        plt.plot(band_rows['time'], band_rows['mag'], 'o', label = band, color = next(colour_iter))
    make_canvas(title = title, xlabel = xlabel, ylabel = ylabel, fontsize = fontsize, show_legend = show_legend, show_grid = show_grid)
    plt.show()


def plot_sampled_spec(inp_table, color = 'blue', title = '', fontsize = 14, show_legend = True, show_grid = True, linewidth = 2, legend = '', figsize = [12,4], show_plot = True):
    """Plot an RVS or XP sampled spectrum (flux vs. wavelength) as a line.

    Parameters
    ----------
    inp_table : astropy.table.Table
        MUST be an Astropy-table object with 'wavelength' and 'flux' columns.
        Their ``.unit`` is used in the axis labels.
    color : str or None
        Line colour. Pass ``None`` to let matplotlib take the next colour from
        its property cycle (useful when overlaying with ``show_plot=False``).
    title : str
        Plot title.
    fontsize : float
        Font size for title, labels and ticks.
    show_legend, show_grid : bool
        Forwarded to ``make_canvas``.
    linewidth : float
        Line width.
    legend : str
        Label of the line in the legend.
    figsize : sequence of two numbers
        Figure size, used only when ``show_plot`` is True.
    show_plot : bool
        If True, open a new figure first and call ``plt.show()`` at the end.
        If False, draw onto the current pyplot axes without showing them, so
        several spectra can be overlaid before the caller shows the figure.
    """
    if show_plot:
        plt.figure(figsize=figsize)
    xlabel = f'Wavelength [{inp_table["wavelength"].unit}]'
    ylabel = f'Flux [{inp_table["flux"].unit}]'
    plt.plot(inp_table['wavelength'], inp_table['flux'], '-', linewidth = linewidth, label = legend, color = color)
    make_canvas(title = title, xlabel = xlabel, ylabel = ylabel, fontsize = fontsize, show_legend = show_legend, show_grid = show_grid)
    if show_plot:
        plt.show()


def make_canvas(title = '', xlabel = '', ylabel = '', show_grid = False, show_legend = False, fontsize = 12):
    """Decorate the current pyplot canvas with a title, axis labels and tick sizes.

    Works on pyplot's current figure/axes; it does not create a new one.

    Parameters
    ----------
    title, xlabel, ylabel : str
        Text for the title and the axis labels.
    show_grid : bool
        Draw a grid when True.
    show_legend : bool
        Add a legend when True. The legend font is 75% of ``fontsize``.
    fontsize : float
        Font size for the title, axis labels and tick labels.
    """
    plt.title(title, fontsize = fontsize)
    plt.xlabel(xlabel, fontsize = fontsize)
    plt.ylabel(ylabel, fontsize = fontsize)
    plt.xticks(fontsize = fontsize)
    plt.yticks(fontsize = fontsize)
    if show_grid:
        plt.grid()
    if show_legend:
        plt.legend(fontsize = fontsize*0.75)

In [7]:
# Alternative query (kept for reference): one source that has epoch photometry.
# query = "SELECT * FROM gaiadr3.gaia_source WHERE has_epoch_photometry = 'TRUE' and source_id = 1035533795140608"

# All Gaia DR3 sources whose best variability classification is 'LPV' (long-period variable).
query = "SELECT * FROM gaiadr3.vari_classifier_result WHERE best_class_name ='LPV'"

# Asynchronous job: the result set is large (about 2.3 million rows when this was run),
# and synchronous Gaia TAP jobs are row-limited.
job     = Gaia.launch_job_async(query)
results = job.get_results()
print(f'Table size (rows): {len(results)}')

INFO: Query finished. [astroquery.utils.tap.core]
Table size (rows): 2325775


In [8]:
# Preview the query result. It is presumably an astropy Table (it is saved via .to_pandas());
# the notebook's rich display shows only a truncated view of the rows.
results

solution_id,source_id,classifier_name,best_class_name,best_class_score
int64,int64,str15,str30,float32
375316653866487564,5902285104890627584,nTransits:5+,LPV,0.038731564
375316653866487564,5902285315377331200,nTransits:5+,LPV,0.9638877
375316653866487564,5902291225252490240,nTransits:5+,LPV,0.92284286
375316653866487564,5902293523025573760,nTransits:5+,LPV,0.978285
375316653866487564,5902296168725541120,nTransits:5+,LPV,0.9871295
375316653866487564,5902298853114705664,nTransits:5+,LPV,0.9231881
375316653866487564,5902299540309456128,nTransits:5+,LPV,0.9609122
375316653866487564,5902309023576226688,nTransits:5+,LPV,0.06635127
375316653866487564,5902311909815916032,nTransits:5+,LPV,0.91618407
375316653866487564,5902316926337867776,nTransits:5+,LPV,0.9341271


In [5]:
# Save the LPV classification scores. index=False: the pandas RangeIndex carries no
# information and would otherwise be written as an unnamed first column.
results.to_pandas().to_csv('vari_long_period_variable_class_score.csv', index=False)