# Make predictions of stellar 'group' on CASSIS spectra

The model was trained on SWS Atlas data.

In [1]:
import glob

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras

from swsnet.dataframe_utils import read_spectrum

In [2]:
def load_model(file_path, **kwargs):
    """Load and return a Keras model saved in HDF5 (``.h5``) format.

    Parameters
    ----------
    file_path : str
        Path to the saved model file.
    **kwargs
        Extra keyword arguments forwarded unchanged to
        ``keras.models.load_model`` (e.g. ``custom_objects``, ``compile``).

    Returns
    -------
    The model object returned by ``keras.models.load_model``.
    """
    # Any exception raised by Keras propagates unchanged; wrapping the call
    # in ``try/except Exception as e: raise e`` added nothing but noise.
    return keras.models.load_model(file_path, **kwargs)

## Load keras model

The model is stored as an HDF5 (`.h5`) file.

In [3]:
# Load the trained classifier and print its layer-by-layer summary.
model_file = 'sws_model_01.h5'
model = load_model(model_file)
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 64)                23040     
_________________________________________________________________
dense_1 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_2 (Dense)              (None, 5)                 325       
Total params: 27,525
Trainable params: 27,525
Non-trainable params: 0
_________________________________________________________________


## Read in metadata (pd.DataFrame)

In [4]:
# Root of the CASSIS data. Keep the trailing slash: paths below are built
# by plain string concatenation onto this prefix.
data_dir = '../../data/cassis/'
metadata_pickle = '{}metadata_step1_normalized.pkl'.format(data_dir)
# NOTE: read_pickle unpickles the file, which can execute code; only load
# trusted local files here.
meta = pd.read_pickle(metadata_pickle)

In [5]:
# Preview the first five rows of the metadata table.
meta.head(n=5)

Unnamed: 0,aorkey,object_name,ra,dec,flux_units,file_path,data_ok
0,3539200,HBC 356,60.808893,25.880773,Jy,spectra_normalized/3539200_renorm.pkl,True
1,3539456,LkCa 1,63.309814,28.317139,Jy,spectra_normalized/3539456_renorm.pkl,True
2,3539712,04108+2803A,63.472431,28.18739,Jy,spectra_normalized/3539712_renorm.pkl,True
3,3539968,MHO-3,63.627656,28.084989,Jy,spectra_normalized/3539968_renorm.pkl,True
4,3540224,Hubble 4,64.696616,28.333013,Jy,spectra_normalized/3540224_renorm.pkl,True


# Perform predictions

In [40]:
def predict_group(spectrum, keras_model=None):
    """Return the model's per-group probabilities for a single spectrum.

    Parameters
    ----------
    spectrum : pandas.DataFrame
        Spectrum with a ``'flux'`` column. Its values are fed to the model
        as one input row, so their length must match the model's input size.
    keras_model : object with a ``predict`` method, optional
        Model used for the prediction. When omitted, the module-level
        ``model`` (created with ``load_model``) is used, so existing
        ``predict_group(spectrum)`` calls behave exactly as before.

    Returns
    -------
    numpy.ndarray
        Output of ``keras_model.predict`` for a batch of one spectrum: a
        single row holding one probability per group.
    """
    if keras_model is None:
        # Fall back to the notebook-global model, resolved at call time.
        keras_model = model

    flux = np.asarray(spectrum['flux'].values)
    # Keras expects a leading batch axis: (n_points,) -> (1, n_points).
    batch = flux[np.newaxis, :]
    return keras_model.predict(batch)

In [70]:
# Predict group probabilities for every spectrum listed in the metadata.
# Each row of results_list is:
#   [running index, aorkey, file_path, <one probability per group>...]
results_list = []

for index, row in enumerate(meta.itertuples()):
    # Progress indicator every 200 spectra.
    if index % 200 == 0:
        print(index)

    # itertuples() yields namedtuples, so columns are plain attributes.
    file_path = row.file_path
    aorkey = row.aorkey

    # data_dir ends with '/', so concatenation yields a valid path.
    spectrum = read_spectrum(data_dir + file_path)
    probabilities = predict_group(spectrum)

    # predict() returns a batch containing one row; unpack that row.
    results_list.append([index, aorkey, file_path, *probabilities[0]])

print('Done.')

0
200
400
600
800
1000
1200
1400
1600
1800
2000
2200
2400
2600
2800
3000
3200
3400
3600
3800
4000
4200
4400
4600
4800
5000
5200
5400
5600
5800
6000
6200
6400
6600
Done.


In [71]:
# Inspect the first prediction row: index, aorkey, file path, then the
# per-group probabilities.
first_result = results_list[0]
first_result

[0,
 3539200,
 'spectra_normalized/3539200_renorm.pkl',
 0.071427196,
 0.8612647,
 0.059757993,
 0.0058639063,
 0.001686169]

In [72]:
# Save every prediction row as comma-separated text. The rows mix ints,
# strings and floats, so np.array produces a string array; hence fmt='%s'.
# NOTE(review): "shifted by one downwards" presumably means probability
# column k corresponds to group k + 1 -- confirm against the training labels.
results_header = 'index, aorkey, file_path, PROBABILITIES (groups 0 - 4) shifted by one downwards.'
results_array = np.array(results_list)
np.savetxt('results.txt', results_array, delimiter=',', fmt='%s',
           header=results_header)