# Training script for astroNN VAC DR17

This notebook contains the training scripts for the astroNN VAC DR17 models (i.e. stellar parameters, distances and ages).

Please note that all the required files can be generated by the scripts provided in the upper-level directory, so you have to either move those files into this directory or move the notebooks from this directory to the upper level.

## Chemical abundances

In [None]:
import numpy as np
import tensorflow as tf
from astropy.io import fits

from astroNN.apogee import aspcap_mask
from astroNN.models import ApogeeBCNNCensored, ApogeeBCNN
from astroNN.apogee import allstar
from astroNN.datasets import xmatch
from astroNN.nn.losses import mean_absolute_error, mean_error, mean_absolute_percentage_error

# allStar catalogues for DR17 and DR16.
f = fits.getdata(allstar(dr=17))
f16 = fits.getdata(allstar(dr=16))

# Spectra come from the default HDU and the flag array from extension 1 of the same file.
# fits.getdata opens, reads and closes the file; the previous bare fits.open(...) call
# left its file handle open for the lifetime of the kernel.
contspec_path = "contspec_dr17_synspec.fits"
contspec_mask = fits.getdata(contspec_path, ext=1)
contspec = fits.getdata(contspec_path)

# loader.target = ['teff', 'logg', 'C', 'C1', 'N', 'O', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'K',
#                  'Ca', 'Ti', 'Ti2', 'V', 'Cr', 'Mn', 'Fe','Co', 'Ni']

# Columns of the DR17 table that feed the training-set selection.
starflag, aspcapflag = f['STARFLAG'], f['ASPCAPFLAG']
vscatter, SNR, location_id = f['VSCATTER'], f['SNR'], f['LOCATION_ID']
teff = f['PARAM'][:, 0]
fe = f['X_H'][:, 17]

# A star is kept only if every one of the criteria below holds.
contspec_ok = np.array(contspec_mask, bool)
flags_clean = (starflag == 0) & (aspcapflag == 0)
observation_ok = (vscatter < 1) & (SNR > 200)
teff_in_range = (f['TEFF'] > 3000) & (f['TEFF'] < 6500)
abundances_finite = ~np.isnan(fe) & ~np.isnan(f['X_H'][:, 5])
good_idx = contspec_ok & flags_clean & observation_ok & teff_in_range & abundances_finite

ra, dec = f["RA"][good_idx], f["DEC"][good_idx]

# Positional cross-match of the selected DR17 stars against DR16:
# idx1 indexes the selected DR17 rows, idx2 the DR16 rows.
idx1, idx2, sep = xmatch(ra, dec, f16["RA"], f16["DEC"])

# Label matrix: columns 0-1 are TEFF and LOGG, columns 2-21 the first 20 X_H entries.
labels = np.column_stack([f["TEFF"], f['LOGG'], f["X_H"][:, :20]])[good_idx].astype(np.float64)
labels_err = np.column_stack([f["TEFF_ERR"], f['LOGG_ERR'], f["X_H_ERR"][:, :20]])[good_idx].astype(np.float64)
input_spec = contspec[good_idx]

# Abundances whose ELEMFLAG is non-zero are replaced by the -9999 sentinel,
# in both the labels and their uncertainties.
bad_aspcap_params = (f["ELEMFLAG"][:, :20][good_idx] != 0)
for label_array in (labels, labels_err):
    label_array[:, 2:][bad_aspcap_params] = -9999.

# inject P_H: label column 10 (X_H column 8) is overwritten with the DR16 value
# for every cross-matched star.
labels[idx1, 10] = f16["X_H"][idx2, 8]
labels_err[idx1, 10] = f16["X_H_ERR"][idx2, 8]

# Set up the astroNN model instance and apply its hyper-parameters in one pass.
bcnn = ApogeeBCNNCensored()
bcnn_settings = {
    "batch_size": 256,
    # Original note: "center label but not scale pixel".
    # NOTE(review): this attribute configures the input normalisation — confirm the wording.
    "input_norm_mode": 3,
    "metrics": [mean_absolute_error, mean_error, mean_absolute_percentage_error],
    "num_hidden": [256, 96, 32, 16, 2],
    "max_epochs": 60,
    "reduce_lr_patience": 5,
    "autosave": True,
}
for setting_name, setting_value in bcnn_settings.items():
    setattr(bcnn, setting_name, setting_value)

# Train on the selected spectra with their labels and label uncertainties.
bcnn.fit(input_spec, labels, labels_err=labels_err)

## Distances

In [None]:
import h5py
import numpy as np
from astropy.io import fits

from astroNN.apogee import allstar
from astroNN.gaia import mag_to_fakemag, extinction_correction, fakemag_to_logsol


# Spectra (HDU 0) and the flag array (HDU 1) are read with fits.getdata, which closes
# the file after reading; the previous fits.open(...) handle was never closed.
contspec_path = "contspec_dr17_synspec.fits"
all_spec = fits.getdata(contspec_path, ext=0)
good_flag = fits.getdata(contspec_path, ext=1)

# DR17 allStar catalogue and the APOGEE DR17 x Gaia EDR3 cross-match table.
allstar_file = fits.getdata(allstar(dr=17))
gaia_data_file = fits.getdata("apogeedr17_syncspec_gaiaedr3_xmatch.fits")

###### Zero-point correction from Gaia ######
from zero_point import zpt

zpt.load_tables()

# Rows that get an individual zero-point: astrometric_params_solved equal to 31 or 95
# and a G magnitude strictly between 6 and 21.
params_solved = gaia_data_file["astrometric_params_solved"]
g_mag = gaia_data_file["phot_g_mean_mag"]
zp_eligible = ((params_solved == 31) | (params_solved == 95)) & (g_mag < 21) & (6 < g_mag)
good_idx = np.where(zp_eligible)[0]

zpt_columns = ("phot_g_mean_mag", "nu_eff_used_in_astrometry", "pseudocolour",
               "ecl_lat", "astrometric_params_solved")
zp = zpt.get_zpt(*[gaia_data_file[column][good_idx] for column in zpt_columns])

# Every star starts from the median zero-point; eligible rows are then
# overwritten with their individual value.
zp_row_matched = np.full(len(allstar_file), np.median(zp), dtype=np.float64)
zp_row_matched[good_idx] = zp
###### Zero-point correction from Gaia ######

# Extinction is taken from the AK_TARG column of the allStar table.
extinction = allstar_file['AK_TARG']
# Assume the extinction is corrupted if it is negative or NaN and replace it with the
# -9999 sentinel. The two tests were previously joined with `&`, which can never be true
# (a NaN never compares below zero), so nothing was ever flagged; `|` flags either case.
extinction[(extinction < 0.) | np.isnan(extinction)] = -9999
ra = gaia_data_file['RA']
dec = gaia_data_file['DEC']
parallax = gaia_data_file['parallax']
# Parallax with the per-star zero-point subtracted.
parallax_w_zp = gaia_data_file['parallax'] - zp_row_matched
parallax_error = gaia_data_file['parallax_error']

# if extinction_method is IRAC then flag it as good (1), otherwise not good (0)
extinction_method = np.zeros_like(ra)
extinction_method[allstar_file['AK_TARG_METHOD'] == 'RJCE_IRAC'] = 1

# Extinction-corrected K magnitude, then fakemag and log-luminosity computed
# both without and with the parallax zero-point applied.
corrected_K = extinction_correction(allstar_file['K'], extinction)
fakemag, fakemag_error = mag_to_fakemag(corrected_K, parallax, parallax_error)
fakemag_w_zp, fakemag_w_zp_error = mag_to_fakemag(corrected_K, parallax_w_zp, parallax_error)
logsol = fakemag_to_logsol(fakemag)
logsol_w_zp = fakemag_to_logsol(fakemag_w_zp)

# Quality cuts shared by the training and testing selections; the two sets
# differ only in the SNR requirement applied afterwards.
k_mag = allstar_file['K']
shared_cut = (
    (~np.isnan(parallax)) & (parallax_error < 0.1) & (parallax < 1e10)
    & (~np.isnan(k_mag)) & (k_mag < 90)
    & (gaia_data_file['ruwe'] < 1.4) & (fakemag != -9999.)
    & (good_flag == 1) & (allstar_file['vscatter'] < 1.) & (allstar_file['STARFLAG'] == 0)
    & ((logsol > 0) | (parallax < 0))
    & (gaia_data_file['ipd_frac_multi_peak'] <= 2)
    & (gaia_data_file['ipd_gof_harmonic_amplitude'] < 0.1)
)

# training set: SNR above 200
good_idx = shared_cut & (allstar_file['SNR'] > 200)

# testing set: SNR below 200
good_test_idx = shared_cut & (allstar_file['SNR'] < 200)

print("Training Set Spectra: ", np.sum(good_idx))
print("Low SNR Combined Spectra Testing Set Spectra: ", np.sum(good_test_idx))

def _write_selection(path, mask):
    """Write the spectra and catalogue columns of the stars selected by ``mask`` to ``path``.

    :param path: name of the HDF5 file to create (an existing file is overwritten)
    :param mask: boolean selection applied to every row-matched array
    """
    # Dataset name -> full-length array, in the order the datasets are created.
    columns = {
        'spectra': all_spec,
        'RA': allstar_file['RA'],
        'DEC': allstar_file['DEC'],
        'SNR': allstar_file['SNR'],
        'allstar_idx': np.arange(allstar_file['RA'].shape[0]),
        'ASPCAP_TEFF': allstar_file['TEFF'],
        'ASPCAP_LOGG': allstar_file['LOGG'],
        'fakemag': fakemag,
        'fakemag_err': fakemag_error,
        'fakemag_w_zp': fakemag_w_zp,
        'fakemag_w_zp_err': fakemag_w_zp_error,
        'corrected_K': corrected_K,  # extinction corrected
        'extinction': extinction,
        'extinction_method': extinction_method,
        'parallax': parallax,
        'parallax_err': parallax_error,
        'parallax_w_zp': parallax_w_zp,
        'bp_rp': gaia_data_file['bp_rp'],
        'phot_g_mean_mag': gaia_data_file['phot_g_mean_mag'],
    }
    # The context manager closes the file even if one of the writes fails.
    with h5py.File(path, 'w') as h5f:
        for dataset_name, values in columns.items():
            h5f.create_dataset(dataset_name, data=values[mask])


# Same set of datasets for both files; only the row selection differs.
_write_selection('gaia_edr3_dr17_syncspec_train.h5', good_idx)
_write_selection('gaia_edr3_dr17_syncspec_test.h5', good_test_idx)

import h5py
import numpy as np
import astropy.units as u
import tensorflow as tf 

from astroNN.nn.losses import mean_absolute_error, mean_error, mean_absolute_percentage_error
from astroNN.models import ApogeeBCNN
from astroNN.nn.callbacks import ErrorOnNaN
from astroNN.gaia import mag_to_fakemag

# Read the training arrays; the context manager ensures the file is cleaned up.
training_keys = ('parallax_w_zp', 'parallax_err', 'fakemag_w_zp',
                 'fakemag_w_zp_err', 'spectra', 'corrected_K')
with h5py.File('gaia_edr3_dr17_syncspec_train.h5') as F:
    # Kcorr is the extinction corrected Ks
    parallax, parallax_error, fakemag, fakemag_err, spectra, Kcorr = [
        np.array(F[key]) for key in training_keys
    ]

# Pixels whose absolute value exceeds 2 are reset to 1.
np.putmask(spectra, np.abs(spectra) > 2, 1.)
# Only stars with a positive fakemag are used for training.
idx = fakemag > 0

#training
bcnn_net = ApogeeBCNN()
distance_settings = {
    "max_epochs": 35,
    "callbacks": ErrorOnNaN(),
    # Original notes: "center label but not scale pixel" (input) and
    # "scale label but not center it" (labels).
    "input_norm_mode": 3,
    "labels_norm_mode": 4,
    "_last_layer_activation": "softplus",
    "targetname": ['Ks-band fakemag'],
    "num_hidden": [192, 64],
    "batch_size": 256,
    "metrics": [mean_absolute_error, mean_error, mean_absolute_percentage_error],
}
for setting_name, setting_value in distance_settings.items():
    setattr(bcnn_net, setting_name, setting_value)

# Labels and their uncertainties are passed as single-column 2-D arrays.
train_labels = fakemag[idx][:, np.newaxis]
train_labels_err = fakemag_err[idx][:, np.newaxis]
bcnn_net.fit(spectra[idx], train_labels, labels_err=train_labels_err)
bcnn_net.save("astroNN_gaia_dr17_modell")

## Age

In [None]:
import h5py
import numpy as np
import astropy.units as u
from astropy.io import fits
from astropy.table import Table, vstack
from astropy.coordinates import SkyCoord
from astroNN.apogee import allstar
from tqdm.notebook import tqdm
from astroquery.vizier import Vizier
from astroquery.simbad import Simbad
from astroNN.datasets import xmatch
from astroNN.models import ApogeeBCNN
from astroNN.nn.losses import mean_absolute_error, mean_error, mean_absolute_percentage_error

allstar_f = fits.getdata(allstar(dr=17))
ra, dec = allstar_f["ra"], allstar_f["dec"]

# The first row's coordinates are overwritten with (0, 0) before cross-matching.
# NOTE(review): presumably row 0 carries an unusable position — confirm.
ra[0] = dec[0] = 0

# load catalogs
apokasc3 = fits.getdata("APOKASC_cat_v6.6.1.fits.zip")
# Keep stars that have an age (not the -9999 sentinel) and whose
# APOKASC2_AGE_MERR / APOKASC2_AGE ratio is below 0.5.
apokasc_age = apokasc3["APOKASC2_AGE"]
apokasc_age_ratio = apokasc3["APOKASC2_AGE_MERR"] / apokasc_age
good_ages = (apokasc_age != -9999.) & (apokasc_age_ratio < 0.5)
apokasc3 = apokasc3[good_ages]

# ask Ted Mackereth for the file somewhere
f_age_low_M = fits.getdata("kepler_low_metallicity_with_samples.fits")

# Cross-match both catalogues against the allStar positions:
# idx_1 / idx_3 index the catalogue rows, idx_2 / idx_4 the allStar rows.
idx_1, idx_2, sep = xmatch(apokasc3["RA"], apokasc3["DEC"], ra, dec)
idx_3, idx_4, sep = xmatch(f_age_low_M["RA"], f_age_low_M["DEC"], ra, dec)


# allStar rows matched by either catalogue. np.unique returns, for each distinct row,
# the position of its first occurrence in the concatenated [idx_4, idx_2] array, so a
# star present in both catalogues takes its labels from the low-metallicity sample.
idx_combined, unique_indices = np.unique(np.concatenate([idx_4, idx_2]), return_index=True)
contspec = fits.getdata("contspec_dr17_synspec.fits")
input_spec = contspec[idx_combined]
# Pixels whose absolute value exceeds 2 are reset to 1.
input_spec[np.abs(input_spec)>2] = 1.

# unique_indices points into the concatenation of the *matched* rows, so each catalogue
# column is first restricted to its matched rows (idx_3 for the low-metallicity sample,
# idx_1 for APOKASC). Concatenating the full columns, as before, shifts the labels against
# the spectra as soon as the cross-match drops a single catalogue row.
all_age = np.concatenate([f_age_low_M["Age_med "][idx_3]/1e9,
                          apokasc3['APOKASC2_AGE'][idx_1]])[unique_indices]
all_age_err = np.concatenate([f_age_low_M["Age_Sd "][idx_3]/1e9,
                              apokasc3['APOKASC2_AGE_MERR'][idx_1]])[unique_indices]
all_mass = np.concatenate([f_age_low_M["Mass_med "][idx_3],
                           apokasc3['APOKASC2_MASS'][idx_1]])[unique_indices]
all_mass_err = np.concatenate([f_age_low_M["Mass_Sd "][idx_3],
                               apokasc3['APOKASC2_MASS_RANERR'][idx_1]])[unique_indices]

# Two-column label arrays: column 0 is age, column 1 is mass.
agemass = np.stack([all_age, all_mass]).T
agemass_err = np.stack([all_age_err, all_mass_err]).T

# Age/mass network: configure the ApogeeBCNN instance, train it on the matched spectra
# and save it.
bcnn_net = ApogeeBCNN()
bcnn_net.max_epochs = 60
# Original note: "center label but not scale pixel".
# NOTE(review): this attribute configures the input normalisation — confirm the wording.
bcnn_net.input_norm_mode = 3
bcnn_net.labels_norm_mode = 4  # scale label but not center it
bcnn_net._last_layer_activation = "softplus"
bcnn_net.targetname = ['age', 'mass']
bcnn_net.num_hidden = [128,64]
bcnn_net.batch_size = 32
bcnn_net.reduce_lr_patience = 4
bcnn_net.metrics = [mean_absolute_error, mean_error, mean_absolute_percentage_error]
bcnn_net.fit(input_spec, agemass, labels_err=agemass_err)
# The closing quote was missing here, which made the whole cell a SyntaxError.
bcnn_net.save("APOKASC2_BCNN_age_combined_dr17")