In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split

import glob
from astropy.io.votable import parse_single_table
import numpy as np

In [3]:
# Attach each source's flux array to the combined parameter table and persist
# the result as a new parquet file.
flux_df = pd.read_parquet('../data/Gaia DR3/combined/df_xp.parquet')
flux_data_files = glob.glob('../data/Gaia DR3/spectra/XP/*.vot')


def _read_flux(vot_path):
    """Return the 'flux' column of the single table stored in ``vot_path`` as a NumPy array."""
    table_df = parse_single_table(vot_path).to_table().to_pandas()
    return table_df['flux'].to_numpy()


# Key = integer found between the last underscore of the path and the next dot.
# NOTE(review): assumes every file is named "..._<source_id>.vot" — confirm
# against whatever produced the files.
source_id_to_flux = {
    int(vot_path.split('_')[-1].split('.')[0]): _read_flux(vot_path)
    for vot_path in flux_data_files
}

# Series.map with a dict leaves NaN for any source_id that has no matching file.
flux_df['flux'] = flux_df['source_id'].map(source_id_to_flux)

flux_df.to_parquet('../data/Gaia DR3/combined/massive_xp.parquet')

In [4]:
# Sanity check: (rows, columns) of the combined table, including the attached 'flux' column.
flux_df.shape

(23636, 7)

In [6]:
# List the columns of the combined table; the task-specific tables are built by dropping subsets of these.
flux_df.columns

Index(['source_id', 'teff_gspphot', 'logg_gspphot', 'mh_gspphot',
       'spectraltype_esphs', 'Cat', 'flux'],
      dtype='object')

In [5]:
# Count rows per 'Cat' value (the column the classification split is stratified on).
flux_df['Cat'].value_counts()

Cat
M     11909
LM    11727
Name: count, dtype: int64

In [None]:
# Derive two task-specific tables from the combined frame; flux_df itself is
# left untouched because DataFrame.drop returns a new frame.
_ap_excluded = ['flux', 'Cat', 'spectraltype_esphs']
_bin_cat_excluded = ['teff_gspphot', 'logg_gspphot', 'mh_gspphot']

# ap_df: everything except the flux array, 'Cat' and 'spectraltype_esphs'.
# NOTE(review): this table therefore carries no 'flux' column — confirm that
# downstream code joins the spectra back in via source_id.
ap_df = flux_df.drop(columns=_ap_excluded)

# bin_cat_df: everything except the three *_gspphot columns.
bin_cat_df = flux_df.drop(columns=_bin_cat_excluded)

In [27]:
# Show every row of bin_cat_df that has a null in at least one column
# (an empty result means no source is missing its flux or labels).
_row_has_null = bin_cat_df.isna().any(axis='columns')
missing_rows = bin_cat_df.loc[_row_has_null]
missing_rows

Unnamed: 0,source_id,spectraltype_esphs,Cat,flux


In [14]:
# 80/20 random split of the parameter table, written as parquet and CSV.
# random_state is fixed so that re-running the notebook reproduces exactly the
# same train/test membership instead of silently overwriting the files on disk
# with a different partition.
train_df, test_df = train_test_split(ap_df, test_size=0.2, random_state=42)

for _split_name, _split_df in (('train', train_df), ('test', test_df)):
    # Drop the shuffled original index so each file starts from 0.
    _split_out = _split_df.reset_index(drop=True)
    _split_out.to_parquet(f'../data/Gaia DR3/splits/ap/ap_{_split_name}.parquet')
    # NOTE(review): to_csv writes the index as an unnamed first column by
    # default; kept as-is in case existing readers rely on it (index_col=0).
    _split_out.to_csv(f'../data/Gaia DR3/splits/ap/ap_{_split_name}.csv')

In [29]:
# 80/20 split of the classification table, stratified on 'Cat' so both parts
# keep the same class proportions. random_state is fixed so that re-running the
# notebook reproduces exactly the same train/test membership instead of
# silently overwriting the files on disk with a different partition.
train_df, test_df = train_test_split(
    bin_cat_df,
    test_size=0.2,
    stratify=bin_cat_df['Cat'],
    random_state=42,
)

for _split_name, _split_df in (('train', train_df), ('test', test_df)):
    # Drop the shuffled original index so each file starts from 0.
    _split_out = _split_df.reset_index(drop=True)
    _split_out.to_parquet(f'../data/Gaia DR3/splits/bin_clf/bin_cat_{_split_name}.parquet')
    # NOTE(review): the 'flux' cells hold arrays, which CSV can only store as
    # their text representation — prefer the parquet files for loading spectra.
    # to_csv also writes the index as an unnamed first column by default; kept
    # as-is in case existing readers rely on it (index_col=0).
    _split_out.to_csv(f'../data/Gaia DR3/splits/bin_clf/bin_cat_{_split_name}.csv')