# In [None]:
'''

Code to retrieve CO slabs/profiles from near-IR NIRSPEC observations. Uses slabspec initial fits as first guesses for the retrievals.
For now, retrieves the inner and outer temperature, the inner and outer column density, and the emitting area (5 variables total), and assumes that temperature and column density decay as power laws.
Assumes r_in = 0.1 and r_out = 10.0 (also for now).
!! marks to-do items or hard-coded things

'''

import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
import astropy.units as u
import pickle as pickle
import os as os
import pandas as pd
from astropy.table import Table, vstack, QTable
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import sys
import iris as iris
from iris import setup
import time
from utils import * # !! are there different utils i need here?

# retrievals:
from dynesty import NestedSampler
import scipy.stats as stats
import pickle
import dynesty.plotting as dyplot
import dynesty.utils as dyfunc
from corner import corner

# Run-mode switches - - - - - - - - - - - - - - - - - - - - - - - - - - -
# 0 -> retrieve a slab model with a single component; 1 -> retrieve a radial profile as defined below
radial_profile = 0
# 0 -> run the full retrieval; 1 -> only generate the forward model from the initial inputs
forward_model_query = 0
# 0 -> plain slab.simulate; 1 -> apply Keplerian broadening
keplerian_query = 1
# 1 -> include the species in the model; 0 -> leave it out
CO_12 = 1
CO_13 = 0

# Plot switches
plot_results_query = 1  # 1 -> plot measured and fit fluxes; 0 -> skip
corner_plot_query = 1  # 1 -> make corner plot of retrieval results; 0 -> skip
# Retrieval-assessing plots make no sense when only a forward model is run, so force them off.
corner_plot_query = 0 if forward_model_query == 1 else corner_plot_query

# Sampler settings (both are handed to dynesty further down)
# !! need a conditional here for dimensions
ndim = 5  # number of free parameters !! make this automatic?
dlogz = 0.1  # termination condition

# Instrument / data settings
cont_jy = 0.95  # continuum level in Jy. Should be ~1 for most nirspec data??
R = 37500  # for NIRSPEC

# filenames and paths - - - - - - - - - - - - - - - - - - - - - - - - - - -
source_name = 'FZTau'
# slabspec output for this source; read as strings because the first entry is the
# source name and the remaining entries are numeric values
input_file = np.loadtxt('./'+source_name+'_output_0', dtype=str)
name_for_files = 'retrieval_nirspec_test_'+source_name
checkpoint_file = '/home/dahlek/prima/retrievals/pickle_files/'+name_for_files+'.save' # file dynesty uses for checkpointing
figure_directory = '/home/dahlek/prima/retrievals/figures/nirspec/' # save figures here
# final results are pickled under the checkpoint name with the '.save' extension stripped
pickle_file = os.path.splitext(checkpoint_file)[0]
# keep this a plain str: np.copy() would yield a 0-d array, which cannot be reliably
# concatenated into the figure paths built later
figure_name = name_for_files
path_to_moldata = '/home/dahlek/HITRAN_data'
path_to_data = '/home/dahlek/spectra/'+source_name+'_fullstack.csv'

# load data and info - - - - - - - - - -
# Read the observed spectrum table and pull its columns out as numpy arrays
infile = Table.read(path_to_data, format='ascii')
df = infile.to_pandas()
wavelength = df['wave'].to_numpy(copy=True) # obs_wgrid
flux = (df['flux'] - cont_jy).to_numpy() # subtract continuum? seems to be normalized to 1
error = df['uflux'].to_numpy(copy=True)
# wavelength limits are taken from the full (unfiltered) wavelength column
wmin = wavelength.min()
wmax = wavelength.max()

wl, flux_ignore, flux_error, wl_min, wl_max, wl_min_intermediate, wl_max_intermediate = readin(path_to_data)

# drop every sample whose flux is NaN, keeping wavelength/flux/error aligned
no_nan_locations = np.flatnonzero(~np.isnan(flux))
flux_without_nans = flux[no_nan_locations]
wavelength_without_nans = wavelength[no_nan_locations]
error_without_nans = error[no_nan_locations]
flux = flux_without_nans.copy()
wavelength = wavelength_without_nans.copy()
error = error_without_nans.copy()

# fine wavelength grid, padded by 0.1 beyond the data on either side
fine_wgrid = np.arange(wmin-0.1 , wmax+0.1, 1e-5)


# inputs from slabspec and variables specific to disk - - - - - - - - - - -

# if CO_12 == 1:
# if 12CO is turned on, load input file as if it were 12CO data
# if we run simultaneous retrievals w/ different species later, will want to label these w/ CO specifically. for now, we're doing one at a time so leave the variable names more generic
source_name = input_file[0]
distance = float(input_file[1]) # pc
inc = float(input_file[2]) # deg
M_star = float(input_file[3]) # M_solar

# input_file was loaded with dtype=str, so every fit value must be cast to float
# before any arithmetic (np.log10 / subtraction on strings raises).
# Each group of nine is: logN, +err, -err, T, +err, -err, logOmega, +err, -err
# high J values
logN_high, logN_perr_high, logN_nerr_high, T_high, T_perr_high, T_nerr_high, logOmega_high, logOmega_perr_high, logOmega_nerr_high = input_file[4:13].astype(float)
# low J values
logN_low, logN_perr_low, logN_nerr_low, T_low, T_perr_low, T_nerr_low, logOmega_low, logOmega_perr_low, logOmega_nerr_low = input_file[13:22].astype(float)
# all lines together
logN_all, logN_perr_all, logN_nerr_all, T_all, T_perr_all, T_nerr_all, logOmega_all, logOmega_perr_all, logOmega_nerr_all = input_file[22:31].astype(float)

# rename and organize to work with preexisting iris script
log_T_slabspec_high = np.log10(T_high); log_T_slabspec_low = np.log10(T_low); log_T_slabspec_all = np.log10(T_all)
log_N_slabspec_high = logN_high-4; log_N_slabspec_low = logN_low-4; log_N_slabspec_all = logN_all-4 # !! subtracting 4 is converting m^-2 to cm^-2 in log space

# convert solid angle to an emitting area in square AU
# NOTE(review): calc_radius comes from utils; presumably it returns a radius in AU for a
# solid angle and a distance in pc -- confirm against utils
omega_high = 10**logOmega_high; omega_low = 10**logOmega_low; omega_all = 10**logOmega_all
log_area_au_high = np.log10(calc_radius(omega_high, distance)**2)
log_area_au_low = np.log10(calc_radius(omega_low, distance)**2)
log_area_au_all = np.log10(calc_radius(omega_all, distance)**2)

dV_slabspec_CO = 2 # want to sample the line well enough
dV_slabspec_13CO = 2

# molecule information: - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Build the list of species handed to the slab model from the CO_12 / CO_13 switches.
# can add more conditionals here for other molecules
if CO_12 == 1 and CO_13 == 0:
    list_of_molecules = ['CO']
    #setup.setup_linelists('CO', 'CO', 1, path_to_moldata)
elif CO_12 == 1 and CO_13 == 1:
    list_of_molecules = ['CO','13CO']
    #setup.setup_linelists('CO', 'CO', 1, path_to_moldata)
    #setup.setup_linelists('13CO', 'CO', 2, path_to_moldata)
elif CO_12 == 0 and CO_13 == 1:
    list_of_molecules = ['13CO']
    #setup.setup_linelists('13CO', 'CO', 2, path_to_moldata)
else:
    # Without this branch list_of_molecules would stay undefined and the script would
    # fail later with an unrelated-looking NameError when the slab is created.
    raise ValueError(
        'No molecule selected: CO_12 and CO_13 must each be 0 or 1, and at least one must be 1 '
        '(got CO_12=' + str(CO_12) + ', CO_13=' + str(CO_13) + ')'
    )


# initiate slab - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# One shared iris slab object covering the data's wavelength range padded by 0.5 on each side.
slab = iris.slab(molecules=list_of_molecules, wlow=wmin-0.5, whigh=wmax+0.5, path_to_moldata=path_to_moldata)

# compiled_slab is defined with a different signature depending on keplerian_query, so every
# call site has to branch on keplerian_query as well.
# NOTE(review): if keplerian_query is neither 0 nor 1, compiled_slab is never defined and the
# jax.jit call below raises NameError.
if keplerian_query == 0:
    def compiled_slab(distance, T_ex, N_mol, A_au, dV, fine_wgrid, wavelength, R):
        """Configure the shared slab object, run slab.simulate(), and return slab.downsampled_flux."""
        slab.setup_disk(distance, T_ex, N_mol, A_au, dV) # initialize object
        slab.setup_grid(fine_wgrid, wavelength, R) # set up wavelength and model parameters
        slab.simulate() # make model
        return slab.downsampled_flux

elif keplerian_query == 1: # output shape will need to be added as an input here, and everywhere else compiled_slab is called
    def compiled_slab(distance, T_ex, N_mol, A_au, dV, r_in, M_star, inc, fine_wgrid, wavelength, R):
        """Configure the shared slab object, run slab.simulate_keplerian(), and return slab.downsampled_flux.

        Note the argument order: this function takes (r_in, M_star, inc) but forwards them to
        slab.setup_disk as (inc, M_star, r_in).
        """
        slab.setup_disk(distance, T_ex, N_mol, A_au, dV, inc, M_star, r_in) # initialize object
        slab.setup_grid(fine_wgrid, wavelength, R) # set up wavelength and model parameters
        slab.simulate_keplerian() # make model
        return slab.downsampled_flux

# JIT-compiled entry point used for the forward model, the likelihood, and the best-fit model
compiled_slab_jit = jax.jit(compiled_slab)


# define slab model, along with temp and column density profiles, etc. - - - - - - - - -

def N_powerlaw(r, N_in, N_out, r_in=0.1, r_out=10.0):
    '''
    Column density power law N(r) = k * r**a that passes through
    (r_in, N_in) and (r_out, N_out).

    Parameters
    -----------
    r : float or array
        radius value(s) in AU at which to evaluate the profile (e.g. rgrid)
    N_in, N_out : float
        column densities at r_in and r_out, taken from the high- and
        low-energy lines, respectively **linear, not log**
    r_in, r_out : float
        inner and outer boundary of the disk in AU

    Returns
    -----------
    float or array
        column density evaluated at r
    '''
    # slope of the straight line joining the two anchor points in log-log space
    delta_log_N = np.log10(N_out) - np.log10(N_in)
    delta_log_r = np.log10(r_out) - np.log10(r_in)
    exponent = delta_log_N / delta_log_r

    # normalisation chosen so that the profile equals N_in at r_in
    prefactor = N_in / (r_in**exponent)

    return prefactor * r**exponent

def T_powerlaw(r, T_in, T_out, r_in=0.1, r_out=10.0):
    '''
    Temperature power law T(r) = k * r**a that passes through
    (r_in, T_in) and (r_out, T_out).

    Parameters
    -----------
    r : float or array
        radius value(s) in AU at which to evaluate the profile (e.g. rgrid)
    T_in, T_out : float
        temperatures at r_in and r_out, taken from the high- and
        low-energy lines, respectively **linear, not log**
    r_in, r_out : float
        inner and outer boundary of the disk in AU

    Returns
    -----------
    float or array
        temperature evaluated at r
    '''
    # slope of the straight line joining the two anchor points in log-log space
    delta_log_T = np.log10(T_out) - np.log10(T_in)
    delta_log_r = np.log10(r_out) - np.log10(r_in)
    exponent = delta_log_T / delta_log_r

    # normalisation chosen so that the profile equals T_in at r_in
    prefactor = T_in / (r_in**exponent)

    return prefactor * r**exponent

def make_annuli(npoints, rin, rout, inc):
    '''
    Split a disk into log-spaced annuli.

    Parameters
    -----------
    npoints : int
        number of annuli (radial points)
    rin, rout : float
        innermost and outermost radius in AU
    inc : float
        disk inclination in degrees

    Returns
    -----------
    Agrid : array
        area of each annulus in AU^2, scaled by cos(inc)
    rgrid : array
        radius in AU at the linear midpoint of each annulus
    '''
    # npoints+1 log-spaced edges bound npoints annuli
    rlow = np.logspace(np.log10(rin), np.log10(rout), npoints+1)
    inner_edges = rlow[0:npoints]

    # width of each annulus; np.diff replaces the old index loop, which relied on a
    # bare except to swallow the IndexError at the final edge
    dr = jnp.array(np.diff(rlow)) # r bin spacing
    rgrid = inner_edges + dr/2 # r where T,N will be calculated

    # pi * (r_outer^2 - r_inner^2), scaled by cos(inc)
    Agrid = jnp.pi * np.cos(inc*np.pi/180) * (((inner_edges+dr)**2) - (inner_edges**2))

    return Agrid, rgrid


# calculate values to make a forward model
# Disk boundaries in AU. r_in is handed to the Keplerian model below, so it has to exist
# before the first model call; the values match the defaults of N_powerlaw / T_powerlaw.
r_in = 0.1; r_out = 10.0

# anchor values of the power laws, from the high- (inner) and low-energy (outer) slabspec fits
T_in = 10**log_T_slabspec_high ; T_out = 10**log_T_slabspec_low
N_in = 10**log_N_slabspec_high ; N_out = 10**log_N_slabspec_low

# NOTE(review): rgrid is not assigned anywhere in this file. Presumably it is provided by
# `from utils import *` or is meant to be built with make_annuli -- confirm.

# Each array below has one row per species (12CO first, then 13CO); a switched-off species
# is zeroed out by its CO_12 / CO_13 flag.
# Excitation temperatures for each molecule in K
T_ex = jnp.array([  np.array([CO_12*T_powerlaw(j, T_in, T_out) for j in rgrid]),
                    np.array([CO_13*T_powerlaw(j, T_in, T_out) for j in rgrid]) ]) # will need to replace T_powerlaw to something specific to 13CO
# column densities in cm^-2
N_mol = jnp.array([ np.array([CO_12*N_powerlaw(j, N_in, N_out) for j in rgrid]),
                    np.array([CO_13*N_powerlaw(j, N_in, N_out) for j in rgrid]) ]) # will need to replace N_powerlaw to something specific to 13CO
# emitting areas in au^2
A_au =  np.array([ np.array([CO_12*10**log_area_au_all]),
                   np.array([CO_13*10**log_area_au_all]) ])
# fixed line widths in km/s (line FWHM)
dV = np.array([ np.array([CO_12*dV_slabspec_CO]),
                np.array([CO_13*dV_slabspec_13CO]) ])


# compiled_slab_jit has a different signature depending on keplerian_query
if keplerian_query == 0:
    model = compiled_slab_jit(distance, T_ex, N_mol, A_au, dV, fine_wgrid, wavelength, R)
elif keplerian_query == 1:
    model = compiled_slab_jit(distance, T_ex, N_mol, A_au, dV, r_in, M_star, inc, fine_wgrid, wavelength, R)

# the model should be sampled on the observed wavelength grid, one value per flux point
if len(model) != len(flux):
    print('Warning!!! mismatch between number of model and flux data points.')


# save the wavelength, measured flux, and modeled flux if you want to
'''
dummy = np.zeros((len(wavelength),3))
dummy[:,0] = wavelength
dummy[:,1] = flux
dummy[:,2] = model
np.savetxt('./forward_model_'+name_for_files, dummy)
'''

# Overlay the initial-guess forward model on the data and save the figure.
if plot_results_query == 1:
    plt.plot(wavelength, flux, color='black', lw=0.5, label='data')
    plt.fill_between(wavelength, model,  color='dodgerblue', alpha=0.6, label='model')

    plt.xlim(wavelength.min()-0.01, wavelength.max()+0.01)
    plt.ylabel('Flux density (Jy)')
    plt.xlabel('Wavelength ($\\mu$m)')
    plt.title('Forward model guess w/ data')

    #plt.ylim(np.min(flux)-1e1, np.max(flux)+1e1 )
    plt.legend()
    # str() keeps the path a plain string even when figure_name is a 0-d numpy array,
    # which cannot be reliably concatenated with str
    plt.savefig(figure_directory+'initial_guess_model_spectrum_'+str(figure_name))
    plt.show()

if forward_model_query == 1:
    sys.exit() # end the program if just running a forward model


# begin setting up retrieval - - - - - - - - - - - - - - - - - - - - - - - - - -

# Prior ranges are centred on the slabspec first guesses.
# !! commenting out 13CO stuff for now

# log column density: +/- 3 around the high- and low-energy slabspec values
min_N_guess_high, max_N_guess_high = log_N_slabspec_high-3.0, log_N_slabspec_high+3.0
min_N_guess_low, max_N_guess_low = log_N_slabspec_low-3.0, log_N_slabspec_low+3.0

# log temperature: +/- 1 around the high- and low-energy slabspec values
min_T_guess_high, max_T_guess_high = log_T_slabspec_high-1.0, log_T_slabspec_high+1.0
min_T_guess_low, max_T_guess_low = log_T_slabspec_low-1.0, log_T_slabspec_low+1.0

#min_temp_guess_13CO = log_T_slabspec_13CO-1.0
#max_temp_guess_13CO = log_T_slabspec_13CO+1.0-min_temp_guess_13CO

# log area: min_area_guess is the lower bound, but max_area_guess is the WIDTH of the
# interval (upper bound minus lower bound), as needed to scale a unit-interval sample
# for the uniform prior
min_area_guess = log_area_au_all-1.0
max_area_guess = (log_area_au_all+1.0) - min_area_guess

#min_area_guess_13CO = log_area_au_13CO-1.0
#max_area_guess_13CO = log_area_au_13CO+1.0-min_area_guess_13CO

# !! need to add conditionals within ln_prior for modeling CO and/or 13CO. can I feed it values aside from uparams?
def _truncnorm_from_unit(u, mean, sigma, lower, upper):
    '''Map a unit-interval sample u onto a normal(mean, sigma) truncated to [lower, upper].'''
    # scipy expects the truncation limits in units of sigma relative to the mean
    lower_std = (lower - mean) / sigma
    upper_std = (upper - mean) / sigma
    return stats.truncnorm.ppf(u, lower_std, upper_std, loc=mean, scale=sigma)


def ln_prior(uparams):
    '''
    Prior transform for dynesty.

    Dynesty samples from a unit cube, so each coordinate in [0, 1] has to be
    mapped onto the range of values wanted for the matching free parameter.

    Parameters
    -----------
    uparams : sequence of 5 floats
        unit-cube samples, ordered as
        (log T high, log T low, log N high, log N low, log A)

    Returns
    -----------
    tuple of 5 floats
        (log_T_high, log_T_low, log_N_high, log_N_low, log_A)

    Notes
    -----------
    log T : truncated normal centred on the slabspec value, sigma 0.5,
            limited to [min_T_guess_*, max_T_guess_*]
    log N : truncated normal centred on the slabspec value, sigma 1.5,
            limited to [min_N_guess_*, max_N_guess_*]
    log A : uniform, starting at min_area_guess with width max_area_guess
    '''
    ulog_T_high, ulog_T_low, ulog_N_high, ulog_N_low, ulog_A = uparams

    # log temperature, high- and low-energy components
    log_T_high = _truncnorm_from_unit(ulog_T_high, log_T_slabspec_high, 0.5,
                                      min_T_guess_high, max_T_guess_high)
    log_T_low = _truncnorm_from_unit(ulog_T_low, log_T_slabspec_low, 0.5,
                                     min_T_guess_low, max_T_guess_low)

    # log column density, high- and low-energy components
    log_N_high = _truncnorm_from_unit(ulog_N_high, log_N_slabspec_high, 1.5,
                                      min_N_guess_high, max_N_guess_high)
    log_N_low = _truncnorm_from_unit(ulog_N_low, log_N_slabspec_low, 1.5,
                                     min_N_guess_low, max_N_guess_low)

    # log area: uniform (max_area_guess holds the interval width, not the upper bound)
    # !! hard-coded guess range for area
    log_A = min_area_guess + ulog_A * max_area_guess

    return log_T_high, log_T_low, log_N_high, log_N_low, log_A



# LOG LIKELIHOOD
def ln_like(params, wavelength=wavelength, flux=flux, error=error, keplerian_query=keplerian_query, inc=inc, M_star=M_star, CO_12=CO_12, CO_13=CO_13):
    '''
    Gaussian log-likelihood (up to an additive constant) of the data given one
    set of retrieved parameters.

    Parameters
    -----------
    params : sequence of 5 floats
        (log_T_high, log_T_low, log_N_high, log_N_low, log_A) as returned by ln_prior
    wavelength, flux, error : arrays
        observed wavelength grid, continuum-subtracted flux, and flux uncertainty
    keplerian_query : int
        0 for the plain slab model, 1 for the Keplerian-broadened model
    inc, M_star : float
        inclination and stellar mass, only used by the Keplerian model
    CO_12, CO_13 : int
        1/0 switches that keep or zero out each species

    Returns
    -----------
    float
        -sum((flux - model)^2 / (2 error^2)); NaN terms are skipped by nansum
    '''
    log_T_high, log_T_low, log_N_high, log_N_low, log_A = params

    # Set up the disk model
    T_in = 10**log_T_high ; T_out = 10**log_T_low
    N_in = 10**log_N_high ; N_out = 10**log_N_low

    # Evaluate both power laws over the whole radial grid in one call rather than
    # point by point; this function runs once per sampler step.
    radii = np.asarray(rgrid)
    T_profile = T_powerlaw(radii, T_in, T_out)
    N_profile = N_powerlaw(radii, N_in, N_out)

    # One row per species (12CO first, then 13CO); a switched-off species is zeroed out.
    # Excitation temperatures for each molecule in K
    T_ex = jnp.array([CO_12*T_profile, CO_13*T_profile]) # will need to replace T_powerlaw to something specific to 13CO
    # column densities in cm^-2
    N_mol = jnp.array([CO_12*N_profile, CO_13*N_profile]) # will need to replace N_powerlaw to something specific to 13CO
    # emitting areas in au^2 -- must use the SAMPLED log_A; using the fixed slabspec
    # area here would leave the fifth free parameter without any effect on the likelihood
    A_au =  np.array([ np.array([CO_12*10**log_A]),
                       np.array([CO_13*10**log_A]) ])
    # fixed line widths in km/s (line FWHM)
    dV = np.array([ np.array([CO_12*dV_slabspec_CO]),
                    np.array([CO_13*dV_slabspec_13CO]) ])

    # NOTE(review): r_in is read from module scope at call time; it has to be defined
    # before the sampler runs.
    if keplerian_query == 0:
        model = compiled_slab_jit(distance, T_ex, N_mol, A_au, dV, fine_wgrid, wavelength, R)
    elif keplerian_query == 1:
        model = compiled_slab_jit(distance, T_ex, N_mol, A_au, dV, r_in, M_star, inc, fine_wgrid, wavelength, R)
    else:
        raise ValueError('keplerian_query must be 0 or 1')

    # compute once, then report and return the same value
    log_like = - np.nansum( 1/(2*(error**2)) * ((flux - model)**2) )
    print('log_T_high, log_T_low, log_N_high, log_N_low, log_A ', log_T_high, log_T_low, log_N_high, log_N_low, log_A )
    print('log liklihood', log_like)
    return log_like



# Set up the Nested Sampler - - - - - - - - - - - - - - - - - - - - - - - -
print('Setting up the Nested Sampler...')
# 20 live points per free parameter
sampler_options = dict(bound='multi', sample='rwalk', nlive=20*ndim)
dsampler = NestedSampler(ln_like, ln_prior, ndim=ndim, **sampler_options)

# Run retrieval - - - - - - - - - - - - - - - - - - - - - - - - - - - -
print('Running retrieval...')

# checkpoint_every is forwarded to dynesty together with the checkpoint file path
dsampler.run_nested(
    dlogz=dlogz,
    maxcall=2e6,
    checkpoint_file=checkpoint_file,
    checkpoint_every=600,
    print_progress=True,
)

# Get the results, save - - - - - - - - - - - - - - - - - - - - - -
dres = dsampler.results

with open(pickle_file, 'wb') as results_file:
    pickle.dump(dres, results_file)


# Get the WEIGHTED samples - - - - - - - - - - - - - - - - - - - - -
# Resample the nested-sampling output into equally weighted posterior samples
weights = dres.importance_weights()
wsamps = dyfunc.resample_equal(dres.samples, weights)

# posterior mean and standard deviation of each parameter
mean = np.mean(wsamps, axis=0)
std = np.std(wsamps, axis=0)

# pull T, N, A, and power laws for best-fit model: - - - - - - - - - - - - - -

log_T_high, log_T_low, log_N_high, log_N_low, log_A = mean[0], mean[1], mean[2], mean[3], mean[4]
# add lines here for 13CO or other molecules

# the quoted uncertainty is 3 standard deviations
print('Best fit log T (high energy) =', round(mean[0],2), "+-", round(3*std[0], 4))
print('Best fit log T (low energy) =', round(mean[1],2), "+-", round(3*std[1], 4))
print('Best fit log N (high energy) =', round(mean[2],2), "+-", round(3*std[2], 4))
print('Best fit log N (low energy) =', round(mean[3],2), "+-", round(3*std[3], 4))
print('Best fit log A (12CO) =', round(mean[4],2), "+-", round(3*std[4], 4))

r_in = 0.1; r_out = 10
T_in = 10**log_T_high ; T_out = 10**log_T_low
N_in = 10**log_N_high ; N_out = 10**log_N_low

# Power-law coefficients, written out so they can be reported. Each prefactor must use
# its OWN exponent (a_N / a_T); there is no shared variable `a`.
a_N = (np.log10(N_out)-np.log10(N_in))/(np.log10(r_out)-np.log10(r_in)) # a = (log(y2) - log(y1)) / (log(x2) - log(x1)) # exponent
k_N = N_in/(r_in**a_N) # k = y1 / (x1^a) # factor
print('Col density power law: N(r) =', str(k_N),'r^',a_N)
a_T = (np.log10(T_out)-np.log10(T_in))/(np.log10(r_out)-np.log10(r_in)) # a = (log(y2) - log(y1)) / (log(x2) - log(x1)) # exponent
k_T = T_in/(r_in**a_T) # k = y1 / (x1^a) # factor
print('Temperature power law: T(r) =', str(k_T),'r^',a_T)


# One row per species (12CO first, then 13CO); a switched-off species is zeroed out.
# Excitation temperatures for each molecule in K
T_ex = jnp.array([  np.array([CO_12*T_powerlaw(j, T_in, T_out) for j in rgrid]),
                    np.array([CO_13*T_powerlaw(j, T_in, T_out) for j in rgrid]) ]) # will need to replace T_powerlaw to something specific to 13CO
# column densities in cm^-2
N_mol = jnp.array([ np.array([CO_12*N_powerlaw(j, N_in, N_out) for j in rgrid]),
                    np.array([CO_13*N_powerlaw(j, N_in, N_out) for j in rgrid]) ]) # will need to replace N_powerlaw to something specific to 13CO
# emitting areas in au^2
A_au =  np.array([ np.array([CO_12*10**log_A]),
                   np.array([CO_13*10**log_A]) ])
# fixed line widths in km/s (line FWHM)
dV = np.array([ np.array([CO_12*dV_slabspec_CO]),
                np.array([CO_13*dV_slabspec_13CO]) ])

# generate slab model using best-fit results
if keplerian_query == 0:
    bfmodel = compiled_slab_jit(distance, T_ex, N_mol, A_au, dV, fine_wgrid, wavelength, R)
elif keplerian_query == 1:
    bfmodel = compiled_slab_jit(distance, T_ex, N_mol, A_au, dV, r_in, M_star, inc, fine_wgrid, wavelength, R)


# plot results - - - - - - - - - - - - - - - - - - - - - - - - - - - -

# All three result figures share the same switch, so they are drawn under a single guard.
if plot_results_query == 1:
    # --- data vs best-fit model spectrum ---
    # !! need to break this into two subplots w/ intermediate wl breaking it up
    plt.figure(figsize=(17,3))
    plt.step(wavelength, flux, color='black', lw=0.5, label='data')
    plt.fill_between(wavelength, bfmodel, color='dodgerblue', alpha=0.6, label='best fit model')

    plt.xlim(wavelength.min(), wavelength.max())
    plt.ylabel('Flux (Jy)')
    plt.xlabel('Wavelength ($\\mu$m)')
    plt.title('Test retrieval - FZ Tau')

    #plt.ylim(-0.02, 0.2)
    plt.legend()
    plt.savefig(figure_directory+'model_spectrum_'+name_for_files)

    # --- retrieved temperature profile, annotated with its power law ---
    plt.figure(figsize=(8,4))
    plt.plot(rgrid, T_powerlaw(rgrid, T_in, T_out))
    plt.title('Temperature profile')
    plt.ylabel('Temperature (K)')
    plt.xlabel('Radius (AU)')
    temp_label = 'T(r) = '+str(k_T)+' r^'+str(a_T)
    plt.text(rgrid[-5], (T_in-T_out)/2, temp_label)
    plt.savefig(figure_directory+'temp_profile_'+name_for_files)

    # --- retrieved column density profile, annotated with its power law ---
    plt.figure(figsize=(8,4))
    plt.plot(rgrid, N_powerlaw(rgrid, N_in, N_out))
    plt.title('Column density profile')
    plt.ylabel('Column density cm^-2')
    plt.xlabel('Radius (AU)')
    coldens_label = 'N(r) = '+str(k_N)+' r^'+str(a_N)
    plt.text(rgrid[-5], (N_in-N_out)/2, coldens_label)
    plt.savefig(figure_directory+'coldens_profile_'+name_for_files)
    
# Retrieval diagnostics: corner plot and trace plot of the posterior.
if corner_plot_query == 1:
    # one label per free parameter, in the order the sampler uses
    # (the same order ln_prior returns them in)
    param_labels = ['log_T_high', 'log_T_low', 'log_N_high', 'log_N_low', 'log_A']

    dyplot.cornerplot(dres, labels=param_labels, label_kwargs={'fontsize':15})
    # str() keeps the path a plain string even when figure_name is a 0-d numpy array
    plt.savefig(figure_directory+'corner_plot_'+str(figure_name))

    # trace plot
    dyplot.traceplot(dres, labels=param_labels, label_kwargs={'fontsize':15})
    plt.tight_layout()
    plt.savefig(figure_directory+'trace_plot_'+str(figure_name))
                                                               