## This notebook describes adversarial training for dealing with the unknown Gaia DR2 parallax offset

## Incomplete notebook, work in progress

## First train a NN with Gaia DR2 parallax

In [None]:
from astroNN.models import ApogeeBCNN
from astroNN.nn.callbacks import ErrorOnNaN
from astroNN.gaia import mag_to_fakemag
from astroNN.nn.metrics import mean_absolute_error, mean_error, mean_absolute_percentage_error
import h5py
import numpy as np

# Pull the three training arrays into memory while the HDF5 file is open;
# the context manager closes the file as soon as they have been copied.
with h5py.File('gaia_dr2_train.h5', 'r') as F:
    spectra, fakemag, fakemag_err = (
        np.array(F[name]) for name in ('spectra', 'fakemag', 'fakemag_err')
    )  # fakemag from Gaia+0.05mas

# Build the Bayesian CNN and configure it before training.
bcnn_net = ApogeeBCNN()
bcnn_net.callbacks = ErrorOnNaN()
bcnn_net.num_hidden = [192, 96]
bcnn_net.max_epochs = 40
bcnn_net.metrics = [
    mean_absolute_error,
    mean_error,
    mean_absolute_percentage_error,
]
bcnn_net.autosave = True

# Spectra are the inputs, fakemag the labels, fakemag_err the label uncertainty.
bcnn_net.train(spectra, fakemag, labels_err=fakemag_err)

## Then train a simple n-deg polynomial regression model on parallax

In [None]:
from astroNN.gaia import gaiadr2_parallax
from astroNN.models import SimplePolyNN
import numpy as np

# Fetch Gaia DR2 parallaxes with offset=0. so no zero-point correction is
# applied here. NOTE(review): the exact meaning of `cuts=1.` is defined by
# astroNN.gaia.gaiadr2_parallax -- confirm against its documentation.
ra, dec, parallax, parallax_err = gaiadr2_parallax(cuts=1., keepdims=False, offset=0.)

# The network takes a column vector, so reshape once and reuse it.
parallax_column = np.atleast_2d(parallax).T

gaia_nn = SimplePolyNN()
gaia_nn.max_epochs = 2
gaia_nn.num_hidden = 0  # degree of polynomial
gaia_nn.l2 = 0
# Train parallax -> parallax, i.e. the same array is used as input and label.
gaia_nn.train(parallax_column, parallax_column)
print("n-deg poly weights: ", gaia_nn.get_weights()[0][0][0])
# Save under the folder name that the adversarial-training cell loads with
# load_folder("gaia_0deg_0823"); change both together if you rename it.
gaia_nn.save("gaia_0deg_0823")

## The adversarial training

In [1]:
import h5py
import numpy as np
import keras.backend as K
from tensorflow import Graph, Session
from astroNN.gaia import gaiadr2_parallax, mag_to_fakemag, extinction_correction, fakemag_to_pc
from astroNN.models import load_folder

# Initial learning rate for the discriminator; the generator runs at lr / 2.
lr = 0.0001

# Discriminator: the spectra -> fakemag network trained earlier.
apogee_nn = load_folder("astroNN_Ks_fakemag")
with apogee_nn.graph.as_default():
    with apogee_nn.session.as_default():
        apogee_nn.max_epochs = 1  # one epoch per adversarial round
        apogee_nn.mc_num = 10  # smaller mc feedforward pass for better performance

# Generator: the polynomial parallax model trained earlier.
gaia_nn = load_folder("gaia_0deg_0823")
with gaia_nn.graph.as_default():
    with gaia_nn.session.as_default():
        gaia_nn.max_epochs = 1  # one epoch per adversarial round

# Open read-only explicitly: older h5py versions default to a writable mode
# and would silently create the file if it were missing.
with h5py.File('gaia_dr2_train.h5', 'r') as F:  # ensure the file will be cleaned up
    parallax = np.array(F['parallax'])
    # only train on star further than 0.1kpc away (parallax < 10 mas)
    # Plain boolean array: wrapping the mask in a list is rejected or
    # misinterpreted as a (1, N) index by current NumPy.
    far_away = parallax < 10.
    parallax = parallax[far_away]
    parallax_error = np.array(F['parallax_err'])[far_away]
    spectra = np.array(F['spectra'])[far_away]
    Kcorr = np.array(F['corrected_K'])[far_away]

counter = 0

for epoch in range(15):
    # Push the current learning rate into both optimizers. Each model lives in
    # its own graph/session, so each must be made the default before touching it.
    with apogee_nn.graph.as_default():
        with apogee_nn.session.as_default():
            K.set_value(apogee_nn.keras_model.optimizer.lr, lr)

    with gaia_nn.graph.as_default():
        with gaia_nn.session.as_default():
            K.set_value(gaia_nn.keras_model.optimizer.lr, lr / 2)

    counter += 1
    if counter > 5:
        # schedule learning rate decay manually: halve it every 6 rounds
        lr /= 2
        counter = 0  # reset counter

    print("===========================================")
    # Generator forward pass: parallax as currently predicted by gaia_nn.
    with gaia_nn.graph.as_default():
        with gaia_nn.session.as_default():
            parallax_gen = gaia_nn.test(np.expand_dims(parallax, axis=1))
            print(f"Current Global Offset: {gaia_nn.get_weights()[0][0][0][0]:.4f}mas")
    fakemag_gen, fakemag_gen_err = mag_to_fakemag(Kcorr, parallax_gen[:, 0], parallax_error)

    # Train the discriminator
    with apogee_nn.graph.as_default():
        with apogee_nn.session.as_default():
            d_loss_real = apogee_nn.train(spectra, fakemag_gen, labels_err=fakemag_gen_err)
            fake_gen, fake_gen_err = apogee_nn.test(spectra)

    # Convert the discriminator's fakemag predictions into distances, then into
    # parallaxes (1000 / distance). The uncertainty uses first-order
    # propagation, sigma_parallax = 1000 * sigma_d / d**2; it is not consumed
    # below but is kept available for inspection.
    dist_teach, dist_teach_err = fakemag_to_pc(fake_gen[:, 0], Kcorr, fake_gen_err['total'][:, 0])
    para_teach = 1000 / dist_teach.value
    para_teach_err = 1000 * dist_teach_err.value / dist_teach.value ** 2

    # Train the generator (to have the discriminator label samples as valid)
    with gaia_nn.graph.as_default():
        with gaia_nn.session.as_default():
            g_loss = gaia_nn.train(parallax, para_teach)

# saving model
with apogee_nn.graph.as_default():
    with apogee_nn.session.as_default():
        apogee_nn.save("astroNN_Ks_fakemag_adversial")

Using TensorFlow backend.


Loaded astroNN model, model type: Bayesian Convolutional Neural Network -> ApogeeBCNN
Loaded astroNN model, model type: Convolutional Neural Network -> SimplePolyNN
Starting Inference
Completed Inference, 1.03s elapsed
Current Global Offset: -0.0000mas
Number of Training Data: 44536, Number of Validation Data: 4948
Epoch 1/1
 - 12s - loss: -2.8568e+00 - output_loss: -2.8569e+00 - variance_output_loss: -2.8569e+00 - output_mean_absolute_error: 0.0529 - output_mean_error: -1.2797e-03 - val_loss: -2.8679e+00 - val_output_loss: -2.8679e+00 - val_variance_output_loss: -2.8679e+00 - val_output_mean_absolute_error: 0.0530 - val_output_mean_error: -1.3587e-02
Completed Training, 12.73s in total
Starting Dropout Variational Inference
Completed Dropout Variational Inference with 10 forward passes, 19.23s elapsed
Number of Training Data: 44536, Number of Validation Data: 4948
Epoch 1/1
 - 2s - loss: 0.7862 - mean_absolute_error: 0.1608 - mean_error: 0.0350 - val_loss: 0.5321 - val_mean_absolute_e

Epoch 1/1
 - 10s - loss: -2.8854e+00 - output_loss: -2.8854e+00 - variance_output_loss: -2.8854e+00 - output_mean_absolute_error: 0.0518 - output_mean_error: 5.9763e-04 - val_loss: -2.9007e+00 - val_output_loss: -2.9007e+00 - val_variance_output_loss: -2.9007e+00 - val_output_mean_absolute_error: 0.0511 - val_output_mean_error: 0.0058
Completed Training, 10.95s in total
Starting Dropout Variational Inference
Completed Dropout Variational Inference with 10 forward passes, 19.23s elapsed
Number of Training Data: 44536, Number of Validation Data: 4948
Epoch 1/1
 - 2s - loss: 0.6611 - mean_absolute_error: 0.1571 - mean_error: 0.0040 - val_loss: 0.4562 - val_mean_absolute_error: 0.1677 - val_mean_error: 0.0064
Completed Training, 2.04s in total
Starting Inference
Completed Inference, 1.24s elapsed
Current Global Offset: 0.0503mas
Number of Training Data: 44536, Number of Validation Data: 4948
Epoch 1/1
 - 10s - loss: -2.8911e+00 - output_loss: -2.8912e+00 - variance_output_loss: -2.8912e+00