In [None]:
# This notebook is meant as a walkthrough of the steps I followed to use The Cannon to derive individual
# abundances for Gaia RVS spectra. Warning: Not a perfect "plug and chug" example (more of a hodgepodge of
# several different notebooks), so use with caution and please make improvements!

In [1]:
# Essentials 
import pandas as pd
import numpy as np
from astropy.io import fits
from astropy.table import Table
import glob
import gzip
import csv
from ast import literal_eval
import os
from IPython.utils import io

from TheCannon import apogee 
from TheCannon import dataset
from TheCannon import model
import astropy.io.fits as pyfits
import astropy.table as tbl

# Change this to where your Gaia data lives. It is set *before* importing
# gaia_tools because the package may read GAIA_TOOLS_DATA at import time;
# setting it afterwards might have no effect on gaia_tools itself.
os.environ['GAIA_TOOLS_DATA'] = 'gaia_data/'

# Useful tool to look at RVS spectra quickly, but non-essential
from gaia_tools.load.spec import read_spec_internal

# Plotting settings
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.colors as colors

# Start from matplotlib's defaults, enable inline figures, then apply the
# notebook's tick/axes styling in one readable mapping.
mpl.rcParams.update(mpl.rcParamsDefault)
get_ipython().run_line_magic('matplotlib', 'inline')
mpl.rcParams.update({
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'font.weight': 'medium',
    'axes.linewidth': 1.5,
    'xtick.major.width': 1.5,
    'xtick.minor.width': 0.75,
    'xtick.minor.visible': True,
    'ytick.major.width': 1.5,
    'ytick.minor.width': 0.75,
    'ytick.minor.visible': True,
})
# High-resolution inline figures (same as the `%config` magic, but valid Python)
get_ipython().run_line_magic('config', "InlineBackend.figure_format = 'retina'")

## Plotting

In [None]:
# Useful function for colorbars
# Source: https://joseph-long.com/writing/colorbars/
def colorbar(mappable):
    """Attach a colorbar to the right of ``mappable``'s axes, matched to their height.

    A new axes (10% of the parent's width, 0.1 in padding) is split off the
    parent axes with ``make_axes_locatable`` and used as the colorbar axes.
    The pyplot "current axes" is restored afterwards, so later ``plt.*``
    calls keep targeting whatever axes was current before the call.

    Returns the created ``Colorbar``.
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    import matplotlib.pyplot as plt

    previous_axes = plt.gca()
    parent_axes = mappable.axes
    colorbar_axes = make_axes_locatable(parent_axes).append_axes("right", size="10%", pad=0.1)
    result = parent_axes.figure.colorbar(mappable, cax=colorbar_axes)
    plt.sca(previous_axes)
    return result

## Cross-matching 

In [None]:
# Calculate the SNR of RVS spectra as mean(flux) / mean(flux_error)
# Source: adapted from The Cannon (exact location not recorded)
def SNR(flux, uncertainty):
    """Estimate a signal-to-noise ratio for each serialized spectrum.

    Parameters
    ----------
    flux, uncertainty : sequence of str (list, ndarray or pandas Series)
        Per-star flux and flux-error arrays stored as text such as
        "[1.0,2.0,...]" (as in the Gaia RVS CSV files). Elements are paired
        by position; any pandas index is ignored.

    Returns
    -------
    ndarray of float, length ``len(flux)``
        ``nanmean(flux) / nanmean(flux_error)`` per star. Entries that are not
        strings (e.g. missing values read back from CSV as NaN) give NaN.

    Raises
    ------
    ValueError
        If ``flux`` and ``uncertainty`` have different lengths.
    """
    if len(flux) != len(uncertainty):
        raise ValueError("flux and uncertainty must have the same length")
    snr = np.full(len(flux), np.nan)
    # zip iterates values positionally, so filtered Series (non 0..n-1 index) work
    for i, (flux_str, err_str) in enumerate(zip(flux, uncertainty)):
        if not (isinstance(flux_str, str) and isinstance(err_str, str)):
            continue  # missing spectrum -> leave NaN
        flux_ = np.fromstring(flux_str.strip("[]"), sep=',')
        flux_error_ = np.fromstring(err_str.strip("[]"), sep=',')
        snr[i] = np.nanmean(flux_) / np.nanmean(flux_error_)
    return snr

In [None]:
# Get RVS data from folder. glob returns paths in arbitrary, filesystem-dependent
# order; sort them so the merged row order (and hence which stars end up in the
# reference set) is reproducible.
gaia_rvs_files = sorted(glob.glob('gaia_data/*.gz'))

In [None]:
# Get APOGEE data from .fits file
fits_file = 'apogee_data/allStar-dr17-synspec_rev1.fits'
apogee_table = Table.read(fits_file, format='fits')
# pandas can only hold scalar columns, so keep only 0-D/1-D columns and drop
# array-valued (multidimensional) ones before converting.
# NOTE(review): FITS string columns likely arrive as bytes; the later
# `.str[2:-1]` on APOGEE_ID appears to strip the resulting b'...' wrapper — confirm.
names = [col for col in apogee_table.colnames if apogee_table[col].ndim <= 1]
dfapogee = apogee_table[names].to_pandas()

In [None]:
# Small merge test case using (1) RVS file.
# skiprows=84 presumably skips each file's metadata/header block; the column
# names are supplied explicitly instead.
gaia_columns = ['source_id', 'solution_id', 'ra', 'dec', 'flux', 'flux_error',
                'combined_transits', 'combined_ccds', 'deblended_ccds']
for path in gaia_rvs_files[:1]:
    print('Processing', path, '...')
    dfgaia = pd.read_csv(path, skiprows=84, names=gaia_columns)

# Inner join: Gaia source_id against the Gaia EDR3 id stored in the APOGEE table
df_merged = dfgaia.merge(dfapogee, left_on='source_id', right_on='GAIAEDR3_SOURCE_ID')

In [None]:
# Show the merged column names (DataFrame.keys() is the column Index); useful
# to know what the full cross-match below will produce.
df_merged.keys()

In [None]:
# Cross-match every RVS file against APOGEE -- will take a bit!
# The per-file merges are collected in a list and concatenated once: growing a
# DataFrame with pd.concat inside the loop is quadratic, and concatenating onto
# an empty object-dtype frame can upcast dtypes / warn in recent pandas.
gaia_columns = ['source_id', 'solution_id', 'ra', 'dec', 'flux', 'flux_error',
                'combined_transits', 'combined_ccds', 'deblended_ccds']
merged_frames = []
for path in gaia_rvs_files:
    dfgaia = pd.read_csv(path, skiprows=84, names=gaia_columns)
    merged_frames.append(pd.merge(dfgaia, dfapogee,
                                  left_on='source_id',
                                  right_on='GAIAEDR3_SOURCE_ID'))

if merged_frames:
    # Each file used to be *prepended*, so the last file's rows came first.
    # Keep that order: the reference-set step picks every `num`-th clean row,
    # so row order decides which stars are selected.
    new_df_merged = pd.concat(merged_frames[::-1], ignore_index=True)
else:
    # No files: keep an empty table with the expected columns
    new_df_merged = pd.DataFrame(columns=df_merged.columns)

# Save the new dataframe so we don't have to do that again!
new_df_merged.to_csv('crossmatched_stars.csv')

In [None]:
# We can read in the new dataframe (its first column is the row index that
# to_csv wrote out).
crossmatched_path = 'crossmatched_stars.csv'
df2 = pd.read_csv(crossmatched_path)

In [None]:
# But we don't have SNR in the Gaia data, so let's add that.
# The column is attached to df2 itself: the reference-set selection below
# filters on df2['gaia_SNR'], which previously only existed on a copy and so
# raised KeyError on a fresh run. Re-running this cell just overwrites it.
df2_SNR = SNR(df2['flux'], df2['flux_error'])
df2 = df2.assign(gaia_SNR=df2_SNR)

# Save a new dataframe with Gaia SNR
df2.to_csv('crossmatched_stars_v3.csv')

## Define reference set

In [None]:
# Thin each row of a 2-D array (here one row per label: logg, Teff, [Fe/H])
# by keeping every `num`-th element, so all rows stay aligned.
def sample(arr, num):
    """Return ``[row[::num] for row in arr]``.

    Parameters
    ----------
    arr : iterable of sliceable rows
        E.g. a 2-D ndarray or a list of lists.
    num : int
        Slice step; must be non-zero (``row[::0]`` raises ValueError).

    Returns
    -------
    list
        One thinned row per input row (ndarray rows stay ndarrays).
    """
    return [row[::num] for row in arr]

In [None]:
# Defining our clean data set
# ASPCAPFLAG bit 23 (STAR_BAD): "BAD overall for star: set if any of TEFF, LOGG, CHI2,
# COLORTE, ROTATION, SN error are set, or any parameter is near grid edge
# (GRIDEDGE_BAD is set in any PARAMFLAG)"
starbad = 1 << 23
flag_ok = np.bitwise_and(df2['ASPCAPFLAG'], starbad) == 0
high_snr = (df2['SNR'] > 200) & (df2['gaia_SNR'] > 100)
param_window = (df2['TEFF'] > 3500) & (df2['TEFF'] < 5500) & (df2['LOGG'] < 3.6)
apo_clean = flag_ok & high_snr & param_window
indices = np.flatnonzero(apo_clean)  # May need to tweak this if your data looks different
df2_clean = df2['source_id'][indices].dropna()

In [None]:
# Defining our reference data set by even sampling.
# Look at only three labels first: logg, Teff and [Fe/H] of the clean stars.
LOGG_, TEFF_, FEH_ = (list(df2[col][indices]) for col in ('LOGG', 'TEFF', 'FE_H'))

# Shape (3, n_clean): one row per label
arr_3D = np.array([LOGG_, TEFF_, FEH_])

# Keep every `num`-th clean star
num = 10
test_sample = sample(arr_3D, num)

In [None]:
# Make sure this looks how you would expect...
new_LOGG = test_sample[0]
new_TEFF = test_sample[1]
new_FEH = test_sample[2]

# One colorbar is drawn for both panels, so both must share the same color
# scale; take it from the full (pre-sampling) [Fe/H] range.
feh_scale = dict(vmin=np.nanmin(FEH_), vmax=np.nanmax(FEH_))

fig, axs = plt.subplots(1, 2, figsize=(10, 4.6))
axs[0].scatter(TEFF_, LOGG_, c=FEH_, s=5, cmap='coolwarm', **feh_scale)
im = axs[1].scatter(new_TEFF, new_LOGG, c=new_FEH, s=5, cmap='coolwarm', **feh_scale)

for ax in axs:
    ax.set_xlabel('T$_{eff}$', size=12)
    ax.set_ylabel('logg', size=12)
    ax.set_xlim(5600, 3400)  # reversed: Teff decreases to the right
    ax.grid()

# Star counts are computed, not hard-coded, so they stay correct when the
# quality cuts or the sampling step change.
axs[0].set_title('Before Sampling', size=12)
axs[0].text(4200, 3.4, f"{len(TEFF_):,} stars", size=12)
axs[1].set_title('After Sampling', size=12)
axs[1].text(4200, 3.4, f"{len(new_TEFF):,} stars", size=12)

cbar = colorbar(im)
cbar.set_label('[Fe/H]', size=12)
plt.tight_layout(h_pad=1)

In [None]:
# If this seems reasonable, proceed: take the same every-`num`-th subset of
# clean-row positions that sample() took of each label row.
new_idx = indices[0::num]

In [None]:
# Labels we want evenly sampled (all): pull the reference rows once, then
# take each label column from them.
ref_rows = df2.loc[new_idx]
f_LOGG = ref_rows['LOGG']
f_TEFF = ref_rows['TEFF']
f_FE_H = ref_rows['FE_H']
f_MG_FE = ref_rows['MG_FE']
f_SI_FE = ref_rows['SI_FE']
f_CA_FE = ref_rows['CA_FE']
f_NI_FE = ref_rows['NI_FE']
f_APOGEE_ID = ref_rows['APOGEE_ID']

In [None]:
# Let's put them into a nice dataframe (one column per label, index = row
# labels of the sampled stars in df2)
ref_labels = pd.DataFrame({
    'APOGEE_ID': f_APOGEE_ID,
    'Teff': f_TEFF,
    'logg': f_LOGG,
    '[Fe/H]': f_FE_H,
    '[Mg/Fe]': f_MG_FE,
    '[Si/Fe]': f_SI_FE,
    '[Ca/Fe]': f_CA_FE,
    '[Ni/Fe]': f_NI_FE,
})

In [None]:
# Save these! This will be our reference set.
# (to_csv also writes the row index as an unnamed first column.)
ref_labels.to_csv(path_or_buf='ref_labels.csv')

In [None]:
# We will need the reference IDs as plain strings. The stored APOGEE_ID text
# presumably looks like "b'2M...'" (a bytes repr), so drop the first two and
# the last character.
fix = df2['APOGEE_ID'][new_idx].astype("string")
fix_ids = fix.str.slice(2, -1)
fix_ids.to_csv('ref_ids.csv', header=False, index=False)

## Training 

In [2]:
# Helpful definitions (modified)

# Modified for this particular case
def load_rvs_spec(source_ids, assume_unique=False, base_path='gaia_data/', wavelength_grid=None):
    """
    NAME:
        load_rvs_spec
    PURPOSE:
        Read corresponding RVS spectra for a list of source id
    INPUT:
        source_ids (int, list, ndarray): source id
        assume_unique (bool): whether to assume the list of source id is unique
        base_path (str): directory containing the Gaia RVS files
            (default 'gaia_data/', unchanged from the original hard-coded value)
        wavelength_grid (ndarray or None): grid passed to read_spec_internal;
            None means np.arange(846, 870.01, 0.01), i.e. 846 to 870 in 0.01
            steps (presumably nm, the RVS range)
    OUTPUT:
        wavelength grid, RVS spectra flux row matched to source_id, RVS spectra corresponding flux uncertainty
    HISTORY:
        2022-06-16 - Written - Henry Leung (UofT)
        base_path / wavelength_grid exposed as optional parameters (defaults unchanged)
    """
    if wavelength_grid is None:
        # Same arange as before, so the default grid is bit-identical
        wavelength_grid = np.arange(846, 870.01, 0.01)
    return read_spec_internal(
        source_ids=source_ids,
        assume_unique=assume_unique,
        base_path=base_path,
        wavelength_grid=wavelength_grid,
    )

# Modified for this particular case
def load_spectra(data_dir):
    """ Reads star names, wavelength, flux, and inverse variance from fits files

    Every file in ``data_dir`` whose name ends in ``'fits'`` is read, in sorted
    order. HDU 1 holds the flux and HDU 2 the flux uncertainty. The wavelength
    grid comes from the first file's HDU 1 header (``CRVAL1``/``CDELT1``),
    which describes a log10-spaced grid that is converted to linear wavelength.

    Parameters
    ----------
    data_dir: str
        Name of the directory containing all of the data files

    Returns
    -------
    names: ndarray of str
        IDs parsed from the filenames (text between 'dr17-' and '.fits')

    wl: ndarray, shape (npixels,)
        Wavelength vector

    fluxes: ndarray, shape (nstars, npixels)
        Flux data values

    ivars: ndarray, shape (nstars, npixels)
        Inverse variance values corresponding to flux values (0 for pixels
        flagged by apogee.get_pixmask)

    Raises
    ------
    ValueError
        If ``data_dir`` contains no matching files.
    """
    print("Loading spectra from directory %s" % data_dir)
    files = sorted(os.path.join(data_dir, filename)
                   for filename in os.listdir(data_dir) if filename.endswith('fits'))
    if not files:
        raise ValueError("No fits files found in directory %s" % data_dir)
    nstars = len(files)
    for jj, fits_file in enumerate(files):
        # Context manager closes every file; np.array copies the data first,
        # so nothing refers to the (possibly memory-mapped) HDU after closing.
        with pyfits.open(fits_file) as file_in:
            flux = np.array(file_in[1].data)
            flux_err = np.array(file_in[2].data)
            if jj == 0:
                npixels = len(flux)
                fluxes = np.zeros((nstars, npixels), dtype=float)
                ivars = np.zeros(fluxes.shape, dtype=float)
                start_wl = file_in[1].header['CRVAL1']
                diff_wl = file_in[1].header['CDELT1']
                # Build exactly npixels log10 wavelengths; np.arange with a float
                # step and end point can return npixels + 1 values.
                wl = 10 ** (start_wl + diff_wl * np.arange(npixels))
        fluxes[jj, :] = flux
        # Same masking / 1/err**2 logic as find_ivar, so reuse it
        ivars[jj, :] = find_ivar(flux, flux_err)
    # Convert filenames to actual IDs
    names = np.array([f.split('dr17-')[1].split('.fits')[0] for f in files])
    print("Spectra loaded")
    return names, np.asarray(wl), fluxes, ivars


# Modified to include the relevant label names for this case
def load_labels(filename):
    """ Extracts reference labels from a file

    Parameters
    ----------
    filename: str
        Name of the CSV file containing the table of reference labels. It needs
        an ID column -- 'ID', or 'APOGEE_ID' as written by the reference-set
        step -- plus the columns 'Teff', 'logg', '[Fe/H]', '[Mg/Fe]',
        '[Si/Fe]', '[Ca/Fe]' and '[Ni/Fe]'.

    Returns
    -------
    labels: ndarray, shape (nstars, 7)
        Label values in the column order above, rows sorted by star ID.

    Raises
    ------
    KeyError
        If there is no ID column or a label column is missing.
    """
    print("Loading reference labels from file %s" % filename)
    data = pd.read_csv(filename)
    if 'ID' in data.columns:
        id_col = 'ID'
    elif 'APOGEE_ID' in data.columns:
        # ref_labels.csv from the reference-set step names the column APOGEE_ID
        id_col = 'APOGEE_ID'
    else:
        raise KeyError("%s has neither an 'ID' nor an 'APOGEE_ID' column" % filename)
    # NOTE(review): rows are sorted by ID (load_spectra also reads files in sorted
    # order). Make sure the training IDs/fluxes paired with these labels are in
    # the same sorted order -- TODO confirm for the Gaia training set.
    order = data[id_col].argsort().to_numpy()
    label_cols = ['Teff', 'logg', '[Fe/H]', '[Mg/Fe]', '[Si/Fe]', '[Ca/Fe]', '[Ni/Fe]']
    return data[label_cols].to_numpy()[order]

# Find inverse variance values corresponding to flux values
# Source: https://github.com/annayqho/TheCannon/blob/8010a0a5dc9a3f9bb91efa79d7756f79b3c7ba9a/TheCannon/apogee.py
def find_ivar(flux, flux_err):
    """Return per-pixel inverse variance ``1 / flux_err**2``.

    Pixels flagged by ``apogee.get_pixmask(flux, flux_err)`` get 0. The result
    is a float array of length ``len(flux)`` (``flux`` and ``flux_err`` are
    assumed to be same-length 1-D arrays).
    """
    bad_pixels = apogee.get_pixmask(flux, flux_err)
    good_pixels = ~bad_pixels
    inverse_variance = np.zeros(len(flux))
    inverse_variance[good_pixels] = 1. / flux_err[good_pixels]**2
    return inverse_variance

# Normalize dtypes of a Gaia RVS frame: flux/flux_error -> pandas "string",
# source_id -> int64
def pre_processing(df):
    """Cast the flux columns to the ``string`` dtype and ``source_id`` to int64.

    ``df`` is modified in place and also returned.
    """
    for column in ('flux', 'flux_error'):
        df[column] = df[column].astype('string')
    df['source_id'] = df['source_id'].astype('int64')
    return df

# This automates The Cannon steps for one chunk of Gaia RVS files

def _read_gaia_chunk(paths):
    """Read source_id/flux/flux_error from Gaia RVS CSV files into one DataFrame."""
    names = ['source_id', 'solution_id', 'ra', 'dec', 'flux', 'flux_error',
             'combined_transits', 'combined_ccds', 'deblended_ccds']
    frames = [
        pd.read_csv(path, skiprows=84, names=names,
                    usecols=['source_id', 'flux', 'flux_error'],
                    dtype={'source_id': 'int64', 'flux': 'string', 'flux_error': 'string'},
                    low_memory=True)
        for path in paths
    ]
    # Concatenate once (not inside the loop); reversed to keep the historical
    # row order, in which each file was prepended.
    return pd.concat(frames[::-1], ignore_index=True)


def _parse_spectra(serialized):
    """Parse "[v1,v2,...]" strings into a list of 1-D float arrays."""
    return [np.fromstring(text.strip("[]"), sep=',') for text in serialized]


def auto_cannon(chunked_files, count, out_dir='derived_parameters'):
    """Fit The Cannon and infer labels for one chunk of Gaia RVS files.

    Uses module-level training data from the training-set cell: ``wl``,
    ``tr_ID``, ``tr_flux``, ``tr_ivar`` and ``tr_label``.

    Parameters
    ----------
    chunked_files : list of str
        Paths of the Gaia RVS CSV files in this chunk.
    count : int
        Chunk number, used in the output filename.
    out_dir : str, optional
        Output directory (created if missing). Default 'derived_parameters'.

    Writes ``<out_dir>/gaiarvs_params_chunk_<count>.h5`` (key 'df') with one
    row per test star: source_id, the seven inferred labels (model_*), their
    errors (err_*), the dataset SNR and the fit chi-squared. Returns None.
    An empty chunk is skipped.
    """
    if len(chunked_files) == 0:
        print("No files in chunk " + str(count) + "; skipping")
        return

    dfgaia = _read_gaia_chunk(chunked_files)

    # Define test set: missing flux -> 0, missing error -> 999 (near-zero ivar)
    test_flux = np.nan_to_num(np.array(_parse_spectra(dfgaia['flux'])))
    test_errs = np.nan_to_num(np.array(_parse_spectra(dfgaia['flux_error'])), nan=999)
    test_ID = dfgaia['source_id']
    test_ivar = np.array([find_ivar(f, e) for f, e in zip(test_flux, test_errs)])

    ds = dataset.Dataset(wl, tr_ID, tr_flux, tr_ivar, tr_label, test_ID, test_flux, test_ivar)
    # Raw string: '\l' is not a valid escape (SyntaxWarning on newer Pythons)
    ds.set_label_names(['T_{eff}', r'\log g', '[Fe/H]', '[Mg/Fe]', '[Si/Fe]', '[Ca/Fe]', '[Ni/Fe]'])

    # NOTE(review): the training set is identical for every chunk, so this fit
    # is repeated each time; fitting once and reusing the model could save time.
    md = model.CannonModel(2, useErrors=False)
    md.fit(ds)

    # Return value is indexed as [0] = label errors, [1] = chi-squareds; the
    # inferred labels are then read from ds.test_label_vals.
    label_errs = md.infer_labels(ds)
    test_labels = ds.test_label_vals
    opt_err = label_errs[0]
    chisqs = label_errs[1]
    snr = ds.test_SNR

    label_keys = ['teff', 'logg', 'feh', 'mgfe', 'sife', 'cafe', 'nife']
    columns = {'source_id': test_ID}
    columns.update({'model_' + key: [row[k] for row in test_labels]
                    for k, key in enumerate(label_keys)})
    columns.update({'err_' + key: [row[k] for row in opt_err]
                    for k, key in enumerate(label_keys)})
    columns['SNR'] = snr
    columns['chisqs'] = chisqs
    model_df = pd.DataFrame(columns)

    os.makedirs(out_dir, exist_ok=True)
    model_df.to_hdf(os.path.join(out_dir, 'gaiarvs_params_chunk_' + str(count) + '.h5'), key='df')

In [None]:
# Training set steps

# Read-in crossmatched data
df = pd.read_csv('crossmatched_stars_v3.csv')

# We only need the wavelength grid, so load any one star's spectrum
# (its flux / flux_err are not used below)
wavelength, flux, flux_err = load_rvs_spec(6772307906968335104)

# Read-in IDs for reference stars in crossmatched data. Relative path: this is
# the file written by the reference-set step (no user-specific absolute path).
df1 = pd.read_csv('ref_ids.csv', header=None, names=['IDs']).astype("string")

# All IDs in crossmatched data, with the same [2:-1] stripping used for ref_ids.csv
ap_id = list(df['APOGEE_ID'].str[2:-1])

# Find reference stars in crossmatched data: row position of the first
# occurrence of each ID (dict lookup instead of a list.index scan per star)
first_row = {}
for row, apogee_id in enumerate(ap_id):
    first_row.setdefault(apogee_id, row)

order = []
missing_ids = []
for ref_id in df1['IDs']:
    if not pd.isna(ref_id) and ref_id in first_row:
        order.append(first_row[ref_id])
    else:
        missing_ids.append(ref_id)
if missing_ids:
    # Previously skipped silently, which leaves tr_ID shorter than the labels/fluxes
    print("Warning: %d reference IDs not found in crossmatched data" % len(missing_ids))

# Read in reference flux and errors (one row per reference star, presumably
# in ref_ids.csv order -- TODO confirm). Distinct names so the crossmatch
# table df2 is not overwritten.
ref_flux_df = pd.read_csv('gaia_fluxes.csv', header=None)
ref_flux_err_df = pd.read_csv('gaia_flux_errs.csv', header=None)

# Clean up: missing flux -> 0, missing error -> 999 (near-zero weight)
fluxes = ref_flux_df.fillna(0).to_numpy()
flux_errs = ref_flux_err_df.fillna(999).to_numpy()

# Calculate ivars
all_ivars = np.array([find_ivar(f, e) for f, e in zip(fluxes, flux_errs)])

# All reference stars
all_IDs = np.array(df['APOGEE_ID'].str[2:-1][order])
all_labels = load_labels("ref_labels.csv")
wl = wavelength

# NOTE(review): load_labels sorts rows by ID while all_IDs follows ref_ids.csv
# order -- confirm labels, IDs and fluxes refer to the same stars row by row.
assert len(all_IDs) == len(all_labels) == len(fluxes) == len(all_ivars), (
    "training IDs (%d), labels (%d) and spectra (%d) differ in length"
    % (len(all_IDs), len(all_labels), len(fluxes)))

# All reference stars to training set
tr_flux = fluxes
tr_ID = all_IDs
tr_label = all_labels
tr_ivar = all_ivars

In [None]:
# Folder with gaia RVS files (sorted so chunk contents are reproducible)
gaia_rvs_files = sorted(glob.glob('gaia_data/*.gz'))

# Split data into digestible chunks of CHUNK_SIZE files
CHUNK_SIZE = 100


def chunk_list(items, size):
    """Split ``items`` into consecutive slices of at most ``size`` elements.

    Every element lands in exactly one chunk, whatever ``len(items)`` is.
    (The previous fixed 0..3300 bound created empty chunks or missed files,
    and its last-chunk branch sliced to ``len - 1``, always dropping the
    final file.)
    """
    return [items[start:start + size] for start in range(0, len(items), size)]


chunks = chunk_list(gaia_rvs_files, CHUNK_SIZE)

In [None]:
# Run The Cannon for each data chunk and keep track of which "chunk"
# (numbered from 1, matching the output filenames).
counter = 0
failed_chunks = []
for counter, chunk in enumerate(chunks, start=1):
    try:
        auto_cannon(chunk, counter)
        print("Done with chunk " + str(counter))
    except Exception as err:
        # Best effort: keep going, but report *why* it failed. A bare except
        # would also trap KeyboardInterrupt, making the run hard to stop.
        failed_chunks.append(counter)
        print("Failed at chunk " + str(counter) + ": " + repr(err))

if failed_chunks:
    print("Chunks that failed:", failed_chunks)

In [None]:
# You should end up with a folder (derived_parameters/) of per-chunk files of derived labels, which you can combine as you'd like!