# Retrieval 101 

## PICASO Retrieval Philosophy: How is this "retrieval" code different from others? 

A "retrieval" is a word we use to describe the coupling of a forward model (e.g. picaso) with a statistical sampler (e.g. ultranest, dynesty, pymultinest) in order to conduct spectral inference, i.e. to obtain constraints (credible intervals) on physical parameters. Check out this [great Bayesian workflow tutorial](https://johannesbuchner.github.io/UltraNest/example-sine-bayesian-workflow.html), which walks through the simple exercise of setting up a "retrieval" for a sine model with Gaussian measurement errors. As we will show below, an atmospheric retrieval is no different from this example; the major difference, of course, is the use of an atmospheric model in place of the simple sine curve.

The main difference between PICASO and other codes is that PICASO's retrieval code is much more stripped down. You will see much more similarity between the tutorials below and the tutorials on the websites of the statistical samplers themselves.

**Cons:** if you want more of a "plug and play" approach to retrieving atmospheric properties, this is probably not the code for you, and something like [POSEIDON](https://poseidon-retrievals.readthedocs.io/en/latest/content/notebooks/retrieval_basic.html#Creating-a-Retrieval-Model) might be quicker to set up. **Pros:** if you want to learn how to write your own parameterizations and retrieval setups, then this route might be good for you. We will walk you through setting up several flavors of retrievals (grid only, grid + post-processing, fully free retrievals). After you learn the basics, we will introduce some of our autobuild tools for writing scripts.

## Making sense of grid fits, gridtrievals, free retrievals

The graphic below should help clarify the differences between these spectral inference schemes. To create a spectrum that we can compare to data, we need as input: planet/stellar parameters, a pressure-temperature profile, and a set of chemical abundances as a function of pressure. The major difference between the three setups is how this input is created. We can either use a set of radiative-convective models to pre-compute the pressure-temperature profile and chemistry, or we can set up simple parameterizations (e.g. something as simple as an isothermal profile) to build these inputs. The "magic sauce" of retrievals is all in how we build these parameterizations.

![image](retrievalfig/retrievalfig.001.jpeg)
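As a concrete example of the simplest parameterization mentioned above, the sketch below builds an isothermal pressure-temperature profile and vertically constant abundances from just a few numbers, which is all a sampler would need to propose in a free retrieval. The helper names, the log-spaced pressure grid from 1e-6 to 100 bar, and the 91 levels are illustrative assumptions, not PICASO defaults or PICASO API; turning these arrays into a spectrum is the job of the forward model.

```python
import numpy as np

def isothermal_profile(temperature, p_top=1e-6, p_bottom=100.0, nlevels=91):
    """Hypothetical helper: a one-parameter isothermal P-T profile.

    Pressures (bar) are log-spaced; the bounds and number of levels are
    illustrative choices only.
    """
    pressure = np.logspace(np.log10(p_top), np.log10(p_bottom), nlevels)
    temp = np.full_like(pressure, float(temperature))
    return pressure, temp

def constant_abundances(pressure, vmr):
    """Hypothetical helper: vertically constant volume mixing ratios.

    vmr maps molecule name -> mixing ratio; each value is broadcast onto
    the pressure grid.
    """
    return {mol: np.full_like(pressure, value) for mol, value in vmr.items()}

# In a free retrieval, these three numbers would be proposed by the sampler
pressure, temp = isothermal_profile(1200.0)
chem = constant_abundances(pressure, {'H2O': 1e-3, 'CO2': 1e-5})
```

A grid-based setup would instead look up `pressure`, `temp`, and `chem` from pre-computed radiative-convective models, so only the planet parameters indexing the grid would be free.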


## Retrieval Tutorial Overview

Our full retrieval tutorial takes you through: 

1. The basics of our retrievals workflow: the data, the model set, the prior set, the likelihood function (**this tutorial**)
2. The basics of a retrieval analysis: corner plots, Bayesian evidence, 1-3 sigma banded spectra (**this tutorial**)
3. Implementation: Grid Fitting -- Using only pre-computed grids to get best fit grid values
4. Implementation: Grid-trieval -- Using a pre-computed grid and post-processing values such as clouds, added chemistry, etc
5. Implementation: Free retrieval -- Building parameterizations 

All of these retrieval tutorials use the [WASP-17 b data from Grant et al.](https://zenodo.org/records/8360121) and the associated pre-computed PICASO model grid.

# The Basics of Retrieval Workflow

All PICASO retrieval scripts will include at least: 

1. function `get_data` : a function to get your spectral data
2. class `param_set` : a class that contains the set of free parameters you want to retrieve
3. class `guesses_set` : an optional class that will help you test your models 
4. class `model_set` : a class that contains the set of models you want to test
5. class `prior_set` : a class that defines the priors
6. function `loglikelihood` : a function that calls your model and computes the likelihood 
7. statistical sampling run script : code that executes the sampler (e.g. ultranest)

Below we break these down using the simple example of fitting a line to a spectrum. If you understand the basics of fitting a model to a spectrum and analyzing the results, you will be well positioned to move on to the next tutorials, where we swap the simple line function for a picaso function.

In [None]:
import numpy as np
import os 
import picaso.justdoit as jdi
import picaso.analyze as lyz
import xarray as xr
import matplotlib.pyplot as plt

## Step 1) Develop function to get data

Let's create a function that pulls our data for us automatically and returns it in a standard format that the rest of the workflow can rely on.

Note: this format is only a recommendation and you can change any part of this to fit your needs 

In [None]:
def get_data(): 
    """
    Create a function to process your data in any way you see fit.
    Here we are using the ExoTiC-MIRI data 
    https://zenodo.org/records/8360121/files/ExoTiC-MIRI.zip?download=1
    But no need to download it.

    Checklist
    ---------
    - your function returns a spectrum in the same units as your picaso model (e.g. (rp/rs)^2, erg/s/cm/cm or other) 
    - your function returns a spectrum in ascending order of wavenumber 
    - your function returns a dictionary where the key specifies the instrument name (in the event there are multiple)

    Returns
    -------
    dict: 
        key: instrument name; value: [wavenumber (ascending), flux or transit depth, error].
        e.g. {'MIRI LRS':[wavenumber, transit depth, transit depth error], 'NIRSpec G395H':[wavenumber, transit depth, transit depth error]}
    """
    dat = xr.load_dataset(jdi.w17_data())
    #build a dataframe so we can easily convert units and sort
    final = jdi.pd.DataFrame(dict(wlgrid_center=dat.coords['central_wavelength'].values,
                transit_depth=dat.data_vars['transit_depth'].values,
                transit_depth_error=dat.data_vars['transit_depth_error'].values))

    #create a wavenumber grid 
    final['wavenumber'] = 1e4/final['wlgrid_center']

    #always ensure we are ordered correctly
    final = final.sort_values(by='wavenumber').reset_index(drop=True)

    #return a nice dictionary with the info we need 
    returns = {'MIRI_LRS': [final['wavenumber'].values, 
             final['transit_depth'].values  ,final['transit_depth_error'].values]   }
    return returns
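If you want to exercise the rest of this workflow without the WASP-17 b files, a synthetic stand-in that returns the same dictionary format can help, together with a check for the docstring checklist. Both functions below are hypothetical helpers, not part of picaso; the 5 to 12 micron range and the 100 ppm errors are made-up values for testing only. With the real data you could run `check_data_dict(get_data())`.

```python
import numpy as np

def check_data_dict(data_dict):
    """Check the get_data() checklist items that can be tested automatically.

    Each instrument must map to [wavenumber, depth, error] arrays of equal
    length, with wavenumber strictly ascending and all errors positive.
    Units cannot be checked here; that part is still up to you.
    """
    for name, (wno, depth, err) in data_dict.items():
        wno, depth, err = (np.asarray(a, dtype=float) for a in (wno, depth, err))
        if not (wno.shape == depth.shape == err.shape):
            raise ValueError(f"{name}: arrays have different lengths")
        if not np.all(np.diff(wno) > 0):
            raise ValueError(f"{name}: wavenumber is not in ascending order")
        if not np.all(err > 0):
            raise ValueError(f"{name}: errors must be positive")
    return True

def get_fake_data(seed=0):
    """Hypothetical synthetic stand-in for get_data() with the same return format."""
    rng = np.random.default_rng(seed)
    wavelength = np.linspace(5, 12, 40)       # micron (made-up MIRI-like range)
    wno = np.sort(1e4 / wavelength)           # cm^-1, sorted ascending
    err = np.full_like(wno, 1e-4)             # 100 ppm errors (made up)
    depth = 0.0166 + rng.normal(0, 1e-4, wno.size)
    return {'MIRI_LRS': [wno, depth, err]}

data = get_fake_data()
print(check_data_dict(data))
```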

## Step 2) Define Free Parameters

In what follows we build several classes that will (in future tutorials) help us keep track of all the models we test for a single planet. For this tutorial, let's do the simplest thing and retrieve the simplest model possible, a straight line:

$$y = mx + b$$
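Written out in full, the retrieved model has three free parameters, matching `param_set.line`: the slope $m$, the intercept $b$, and $\log_{10}\sigma_\mathrm{inf}$ (`log_err_inf`), an error-inflation term that does not change the line and only enters through the likelihood. Here $y$ is transit depth and $x$ is wavenumber, converted from wavelength as in `get_data`:

$$
y(\tilde\nu) = m\,\tilde\nu + b, \qquad \tilde\nu\,[\mathrm{cm^{-1}}] = \frac{10^4}{\lambda\,[\mu\mathrm{m}]}
$$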

In [None]:
class param_set:
    """
    This is for bookkeeping which parameters you have run in each retrieval.
    It helps if you keep variable names uniform.
    
    Checklist
    ---------
    - Make sure that the order of variables here match how you are unpacking your cube in the model_set class and prior_set
    - Make sure that the variable names here match the function names in model_set and prior_set
    """
    line=['m','b','log_err_inf'] 


## Step 3) Define Initial Guesses

When testing, it is very useful to check that your code is grabbing the right parameters before doing a full analysis. Also, if you choose to use emcee instead of a nested sampler such as UltraNest or MultiNest, these guesses can serve as the starting values for your chains.

In [None]:
class guesses_set: 
    """
    Optional! 

    Tips
    ----
    - Usually you might have some guess (however incorrect) of what the answer might be. You can use this in the testing phase!
    """
    line=[0,0.016633,-5] #one entry per parameter in param_set.line: a zero slope, the measured transit depth reported from exo.mast, and a small error inflation term (log10, so 1e-5)
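Since the guesses can double as emcee starting values, one common way to use them is to scatter a small "ball" of walkers around the guess. A minimal numpy sketch is below; `initial_walkers` is a hypothetical helper, and the per-parameter widths are illustrative (each should be small compared with that parameter's prior range). The resulting `(nwalkers, ndim)` array is the layout emcee expects as the initial state for `run_mcmc`.

```python
import numpy as np

def initial_walkers(guess, nwalkers=32, scale=1e-6, seed=0):
    """Scatter walkers in a small Gaussian ball around a guess vector.

    scale can be a single width or one width per parameter.
    Returns an array of shape (nwalkers, ndim).
    """
    rng = np.random.default_rng(seed)
    guess = np.asarray(guess, dtype=float)
    scale = np.broadcast_to(np.asarray(scale, dtype=float), guess.shape)
    return guess + scale * rng.standard_normal((nwalkers, guess.size))

# Illustrative guess for [m, b, log_err_inf] with per-parameter widths
p0 = initial_walkers([0.0, 0.016633, -5.0], scale=[1e-8, 1e-5, 0.1])
```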

## Step 4) Define Model Set

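Here we define the full model. Each model function receives the parameter vector (called `cube` here) after it has passed through the prior transform, so the values are already in physical units, and returns the model spectrum along with any instrument offsets and error-inflation terms.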

In [None]:
class model_set:
    """
    This is your full model set. It will include all the functions you want to test
    for a particular data set.

    Tips
    ----
    - if you keep the structure of all your returns identical you will thank yourself later. 
      For example, below I always return x, y, dict of instrument offsets, dict of error inflation (if they exist)

    Checklist
    ---------
    - unpacking the cube should unpack the parameters you have set in your param_set class. I like to use 
    list indexing with strings so I don't have to rely on remembering a specific order
    """     
    @staticmethod
    def line(cube): 
        wno_grid = np.linspace(600,3000,int(1e4)) #in the future this will be defined by the picaso opacity db
        m = cube[param_set.line.index('m')] 
        b = cube[param_set.line.index('b')] 
        err_inf = {'MIRI_LRS':10**cube[param_set.line.index('log_err_inf')] }
        y = m*wno_grid + b 
        offsets = {} #I like to keep the returns of all my model sets the same 
        return wno_grid,y,offsets,err_inf
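An alternative to calling `.index()` once per parameter is to zip the names and values into a dictionary in one step. The sketch below is a hypothetical variant of `model_set.line` (`PARAM_NAMES`, `unpack`, and `line_model` are illustrative names). Unlike the `.index()` lookups, it raises an error when the parameter vector has the wrong length, for example a guess with an extra entry.

```python
import numpy as np

PARAM_NAMES = ['m', 'b', 'log_err_inf']   # mirrors param_set.line

def unpack(cube, names=PARAM_NAMES):
    """Turn a parameter vector into a {name: value} dict, checking its length."""
    if len(cube) != len(names):
        raise ValueError(f"expected {len(names)} values, got {len(cube)}")
    return dict(zip(names, cube))

def line_model(cube):
    """Same straight-line model as model_set.line, written with unpack()."""
    p = unpack(cube)
    wno_grid = np.linspace(600, 3000, int(1e4))
    y = p['m'] * wno_grid + p['b']
    err_inf = {'MIRI_LRS': 10 ** p['log_err_inf']}
    offsets = {}
    return wno_grid, y, offsets, err_inf
```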

## Step 5) Define Prior Set

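Next, we store all the priors for UltraNest to use. UltraNest draws points uniformly from a unit hypercube, and each prior function (a "prior transform") maps those values between 0 and 1 onto the physical range of each parameter.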

In [None]:
class prior_set:
    """
    Store all your priors. You should use exactly the same function names here as
    you do in model_set and param_set.

    Checklist
    ---------
    - Make sure the order of the unpacked cube follows the unpacking in your model 
      set and in your parameter set (here we look each parameter up by name in param_set).
    """   
    @staticmethod
    def line(cube):
        """Prior transform: map unit-cube values (0 to 1) onto uniform priors."""
        params = cube.copy()
        names = param_set.line
        #slope: uniform between -1e-5 and 1e-5
        lo, hi = -1e-5, 1e-5
        i = names.index('m')
        params[i] = lo + (hi-lo)*cube[i]
        #intercept: uniform between 0.015 and 0.02
        lo, hi = 0.015, 0.02
        i = names.index('b')
        params[i] = lo + (hi-lo)*cube[i]
        #log10 error inflation: uniform between -10 and 3
        lo, hi = -10, 3
        i = names.index('log_err_inf')
        params[i] = lo + (hi-lo)*cube[i]
        return params
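With many uniform priors, writing one block per parameter gets repetitive. The sketch below builds an equivalent prior transform from a table of bounds; `LINE_BOUNDS` and `make_uniform_prior` are hypothetical names, and the bounds are copied from `prior_set.line`. Non-uniform priors (e.g. Gaussian) would need their own inverse-CDF mapping and are not covered here.

```python
import numpy as np

# Hypothetical bounds table; values copied from prior_set.line
LINE_BOUNDS = {
    'm': (-1e-5, 1e-5),
    'b': (0.015, 0.02),
    'log_err_inf': (-10, 3),
}

def make_uniform_prior(names, bounds):
    """Build a prior transform mapping the unit cube onto uniform ranges."""
    lo = np.array([bounds[n][0] for n in names], dtype=float)
    hi = np.array([bounds[n][1] for n in names], dtype=float)

    def transform(cube):
        cube = np.asarray(cube, dtype=float)
        # broadcasts over a single point (ndim,) or a batch (npoints, ndim)
        return lo + (hi - lo) * cube

    return transform

prior = make_uniform_prior(['m', 'b', 'log_err_inf'], LINE_BOUNDS)
```

Because the arithmetic broadcasts, the same transform also accepts a 2-D array holding many points at once.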


## Step 6) Define Likelihood Function

Most likelihood functions for data with Gaussian errors have the same form (see for example the formalism in [emcee](https://emcee.readthedocs.io/en/stable/tutorials/line/#maximum-likelihood-estimation)).
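For measurement errors $\sigma_i$ with an extra noise term $\sigma_\mathrm{inf}$ added in quadrature, the log-likelihood computed in `loglikelihood` takes the form below, where $y_i$ are the data and $\mu_i$ is the model binned onto the data wavenumbers:

$$
\ln\mathcal{L} = -\frac{1}{2}\sum_i \left[\frac{(y_i-\mu_i)^2}{s_i^2} + \ln\!\left(2\pi s_i^2\right)\right], \qquad s_i^2 = \sigma_i^2 + \sigma_\mathrm{inf}^2
$$

The $\ln(2\pi s_i^2)$ term matters whenever $\sigma_\mathrm{inf}$ is a free parameter: without it, the sampler could improve the fit simply by inflating the errors. When no error-inflation term is present, that term is a constant, which is why the code sets it to zero in that case.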

In [None]:
def loglikelihood(cube):
    """
    Log-likelihood function that is ultimately given to the sampler.
    Note: if you keep to our formats you will not have to change this code much.

    Tips
    ----
    - Remember how we put our data dict, error inflation, and offsets all in dictionary format? Now we can utilize that 
    functionality if we properly named them all with the right keys! 

    Checklist
    --------- 
    - ensure that error inflation and offsets are incorporated in the way that suits your problem 
    - note there are many different ways to incorporate error inflation! this is just one example 
    """
    #compute model spectra
    resultx,resulty,offset_all,err_inf_all = MODEL(cube) # we will define MODEL below 

    #initialize the four lists we will need for the likelihood
    ydat_all=[];ymod_all=[];sigma_all=[];extra_term_all=[]

    #loop through data (if multiple instruments, add offsets if present, add err inflation if present)
    for ikey in DATA_DICT.keys(): #we will also define DATA_DICT below
        xdata,ydata,edata = DATA_DICT[ikey]
        xbin_model , y_model = jdi.mean_regrid(resultx, resulty, newx=xdata)#remember we put everything already sorted on wavenumber

        #add offsets if they exist
        offset = offset_all.get(ikey,0) #if an offset for that instrument doesn't exist, use 0
        ydata = ydata+offset 

        #add error inflation if it exists
        err_inf = err_inf_all.get(ikey,0) #if an err inf term for that instrument doesn't exist, use 0
        sigma = edata**2 + (err_inf)**2 #total variance (sigma squared); there are multiple ways to do this, here we add an extra noise term in quadrature
        if err_inf !=0: 
            #see formalism here for example https://emcee.readthedocs.io/en/stable/tutorials/line/#maximum-likelihood-estimation
            extra_term = np.log(2*np.pi*sigma)
        else: 
            extra_term=sigma*0

        ydat_all.append(ydata)
        ymod_all.append(y_model)
        sigma_all.append(sigma)
        extra_term_all.append(extra_term)

    ymod_all = np.concatenate(ymod_all)    
    ydat_all = np.concatenate(ydat_all)    
    sigma_all = np.concatenate(sigma_all)  
    extra_term_all = np.concatenate(extra_term_all)

    #compute likelihood
    loglike = -0.5*np.sum((ydat_all-ymod_all)**2/sigma_all + extra_term_all)
    return loglike
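One way to convince yourself the likelihood is doing the right thing is to compare it against the log of a product of normal probability densities. This numpy-only sketch isolates the single-instrument math from `loglikelihood` (the function names are hypothetical, and this version always keeps the $\ln(2\pi s^2)$ term) and checks that both routes agree on synthetic data. It also shows why the error-inflation term is favored when the reported errors are too small.

```python
import numpy as np

def gaussian_loglike(ydata, ymodel, edata, err_inf=0.0):
    """Same expression as loglikelihood(), for one instrument."""
    s2 = edata**2 + err_inf**2                      # total variance per point
    return -0.5 * np.sum((ydata - ymodel)**2 / s2 + np.log(2 * np.pi * s2))

def loglike_from_pdf(ydata, ymodel, edata, err_inf=0.0):
    """Independent route: sum of logs of normal probability densities."""
    s = np.sqrt(edata**2 + err_inf**2)
    pdf = np.exp(-0.5 * ((ydata - ymodel) / s)**2) / (np.sqrt(2 * np.pi) * s)
    return np.sum(np.log(pdf))

rng = np.random.default_rng(1)
edata = np.full(50, 1e-4)                           # reported errors (synthetic)
ymodel = np.full(50, 0.0166)
ydata = ymodel + rng.normal(0, 1e-4, 50)
a = gaussian_loglike(ydata, ymodel, edata, err_inf=5e-5)
b = loglike_from_pdf(ydata, ymodel, edata, err_inf=5e-5)
print(a, b)

# If the true scatter is ~3x the reported error, inflating the errors scores higher
noisy = ymodel + rng.normal(0, 3e-4, 50)
print(gaussian_loglike(noisy, ymodel, edata, 3e-4) > gaussian_loglike(noisy, ymodel, edata, 0.0))
```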

## Step 7) Check models, likelihoods, priors! 

Do not underestimate the importance of this step. Always ensure that your model returns sensible values before you start running your sampler.

Ask yourself: 

1. Do random draws from your prior roughly pass through the data?
2. Do they seem skewed? If so, maybe you need to adjust your prior.
3. Do the likelihood values track the models? E.g. lower likelihoods for bad models, higher likelihoods for good models.
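The third question can be checked programmatically as well as by eye. The hypothetical helper below draws from the prior through the unit cube, evaluates the log-likelihood for each draw, and returns the draws ordered best first; once the pieces are defined you could call it as `rank_prior_draws(PRIOR, loglikelihood, len(PARAMS))`. The toy prior and likelihood at the bottom only exist to show the call.

```python
import numpy as np

def rank_prior_draws(prior, loglike, ndim, ndraws=10, seed=0):
    """Draw from the prior via the unit cube and sort draws by log-likelihood.

    Returns (params, logls) ordered best first, so you can print or plot the
    top and bottom draws and confirm that better-looking models score higher.
    """
    rng = np.random.default_rng(seed)
    params = np.array([prior(rng.uniform(size=ndim)) for _ in range(ndraws)])
    logls = np.array([loglike(p) for p in params])
    order = np.argsort(logls)[::-1]
    return params[order], logls[order]

# Toy demonstration: uniform prior on [-1, 1], Gaussian likelihood centered on 0.3
toy_prior_fn = lambda u: -1 + 2 * np.asarray(u)
toy_loglike = lambda p: -0.5 * np.sum((p - 0.3)**2 / 0.01)
params, logls = rank_prior_draws(toy_prior_fn, toy_loglike, ndim=1)
```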

In [None]:
#we can easily grab all the important pieces now that they are neatly stored in a class structure 
DATA_DICT = get_data()
PARAMS = getattr(param_set,'line')
MODEL = getattr(model_set,'line')
PRIOR = getattr(prior_set,'line')
GUESS = getattr(guesses_set,'line')
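Before plotting, a quick structural check can catch mismatches between the classes, such as a guess with more entries than there are free parameters. The helper below is a hypothetical sketch; with the classes above you would call `check_retrieval_set('line', param_set, prior_set, model_set, guesses_set)`. The tiny `demo_*` classes only exist to make the block self-contained.

```python
import numpy as np

def check_retrieval_set(name, param_set, prior_set, model_set, guesses_set=None, seed=0):
    """Hypothetical sanity check that the classes for one retrieval agree.

    Raises ValueError if the guess length or the prior output length does not
    match the number of free parameters, and AttributeError if any class is
    missing a member called `name`.
    """
    params = getattr(param_set, name)
    prior = getattr(prior_set, name)
    getattr(model_set, name)  # only checks that the model exists
    if guesses_set is not None:
        guess = getattr(guesses_set, name)
        if len(guess) != len(params):
            raise ValueError(f"{name}: guess has {len(guess)} entries, "
                             f"but there are {len(params)} free parameters")
    cube = np.random.default_rng(seed).uniform(size=len(params))
    n_out = len(prior(cube))
    if n_out != len(params):
        raise ValueError(f"{name}: prior returned {n_out} values, expected {len(params)}")
    return True

# Minimal stand-in classes, only to show the call
class demo_params:
    line = ['m', 'b']

class demo_guess:
    line = [0.0, 1.0]

class demo_prior:
    @staticmethod
    def line(cube):
        return -1 + 2 * cube

class demo_model:
    @staticmethod
    def line(cube):
        return cube

print(check_retrieval_set('line', demo_params, demo_prior, demo_model, demo_guess))
```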

In [None]:
import matplotlib.pyplot as plt
plt.figure()
#let's plot the data 
for ikey in DATA_DICT.keys(): 
    plt.errorbar(x=DATA_DICT[ikey][0], y=DATA_DICT[ikey][1], yerr=DATA_DICT[ikey][2], marker='o', ls=' ',label='Grant data')

ntests = 10 #let's do 10 random tests 
for i in range(ntests): 
    cube = np.random.uniform(size=len(PARAMS))
    params_evaluations = PRIOR(cube)
    x,y,off,err = MODEL(params_evaluations)
    loglike = loglikelihood(params_evaluations)
    plt.plot(x,y,label=str(int(loglike)))

guessx,guessy,off,err = MODEL(GUESS)
guess_log = loglikelihood(GUESS)
plt.plot(guessx,guessy,color='black',label='guess '+ str(int(guess_log)))
plt.legend()

Looks pretty good! We might be a tiny bit skewed toward higher y-intercept values, but overall things look good.

## Step 8) Run the statistical sampler!! 

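Now we can finally run UltraNest. Once you have gotten this far, you will have the skills to implement other samplers as well: they all need the same ingredients, a log-likelihood function plus either a prior transform (nested samplers such as UltraNest, dynesty, or pymultinest) or a log-prior (MCMC samplers such as emcee).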
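To see that a sampler really only needs the likelihood and the prior transform, here is a deliberately simple random-walk Metropolis sampler in plain numpy. It is a toy for illustration, not a replacement for UltraNest: it does not compute the Bayesian evidence and has no convergence diagnostics. It works in unit-cube coordinates where, for a prior transform that maps the uniform cube onto the prior, the posterior density is proportional to the likelihood inside the cube. In this notebook you would pass `loglikelihood`, `PRIOR`, and `len(PARAMS)`; the toy Gaussian likelihood below is only for demonstration.

```python
import numpy as np

def metropolis_unit_cube(loglike, prior, ndim, nsteps=5000, step=0.05, seed=0):
    """Toy random-walk Metropolis sampler driven by a loglike and a prior transform."""
    rng = np.random.default_rng(seed)
    u = np.full(ndim, 0.5)                      # start at the center of the cube
    logl = loglike(prior(u))
    chain = np.empty((nsteps, ndim))
    for i in range(nsteps):
        prop = u + step * rng.standard_normal(ndim)      # symmetric proposal
        if np.all((prop >= 0) & (prop <= 1)):             # zero prior density outside
            logl_prop = loglike(prior(prop))
            if np.log(rng.uniform()) < logl_prop - logl:  # Metropolis acceptance
                u, logl = prop, logl_prop
        chain[i] = prior(u)                     # store physical parameters
    return chain

# Toy demonstration: uniform prior on [-1, 1], Gaussian likelihood centered on 0.3
def toy_prior(u):
    return -1 + 2 * np.asarray(u)

def toy_loglike(p):
    return -0.5 * np.sum((p - 0.3) ** 2 / 0.1 ** 2)

chain = metropolis_unit_cube(toy_loglike, toy_prior, ndim=1)
print(chain[1000:].mean(), chain[1000:].std())
```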

In [None]:
import ultranest

In [None]:
sampler = ultranest.ReactiveNestedSampler(PARAMS, loglikelihood, PRIOR)
result = sampler.run()

# The Basics of Retrieval Analysis

In future tutorials you will see some pre-defined functions to help you with doing the following analyses. However, on a first go-around in this simple example we will put the native code here so you can see what is going on under the hood. 

## Posterior predictive checks

First let us check that our samples are returning a sensible model for our data. 

In [None]:
from ultranest.plot import PredictionBand

plt.figure()
#the model wavenumber grid is the same for every sample, so set up the band once
x = MODEL(result['samples'][0])[0]
band = PredictionBand(1e4/x)   #transforming x-axis from wavenumber to microns
for params in result['samples']:
    x,y,off,err = MODEL(params)
    band.add(y)

band.line(color='g') #median model

#let's shade the 1, 2, and 3 sigma credible intervals from these samples 
for q in [k/100/2 for k in [68.27, 95.45, 99.73]]: 
    band.shade(q=q, color='g', alpha=0.5)

for ikey in DATA_DICT.keys(): 
    plt.errorbar(x=1e4/DATA_DICT[ikey][0], y=DATA_DICT[ikey][1], yerr=DATA_DICT[ikey][2], marker='o', ls=' ',label='Grant data')
    
plt.xlabel('Wavelength (micron)')
plt.ylabel('Transit depth');
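Under the hood, a prediction band is a set of percentiles of the model draws at each wavelength. A numpy-only sketch of the same idea is below (`prediction_quantiles` is a hypothetical name). It assumes the `q` passed to `band.shade` above is the half-width of a central interval around the median, which is how the `k/100/2` values appear to be used; under that reading, `q=0.6827/2` corresponds to the 15.865th to 84.135th percentiles.

```python
import numpy as np

def prediction_quantiles(models, sigmas=(1, 2, 3)):
    """Median and central bands across posterior model draws.

    models: array of shape (nsamples, npoints), one model spectrum per sample.
    Returns the median curve and, for each sigma level, the (lower, upper)
    curves enclosing 68.27, 95.45 or 99.73 percent of the draws.
    """
    models = np.asarray(models)
    levels = {1: 68.27, 2: 95.45, 3: 99.73}
    median = np.percentile(models, 50, axis=0)
    bands = {}
    for s in sigmas:
        half = levels[s] / 2
        lo, hi = np.percentile(models, [50 - half, 50 + half], axis=0)
        bands[s] = (lo, hi)
    return median, bands
```

With real results you would stack `MODEL(params)[1]` for each row of `result['samples']` and pass the resulting 2-D array.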

Not bad! We are nearly ready to implement a real atmosphere model to fit the data! 


## Parameter posterior probability distribution checks

What to check for in a corner plot: 

1. Are your 1D marginalized posterior probability distributions Gaussian (bell shaped)?
2. Are there correlations between your parameters?
3. Does your posterior probability distribution hit the limit of your prior (check the x axis of each distribution)?

For our example below: m and b appear Gaussian. However, the error inflation term looks to be poorly constrained. We only have an upper limit on the number. 

In [None]:
from ultranest.plot import cornerplot
cornerplot(result)
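The corner-plot questions can also be answered numerically. The hypothetical helper below prints a median and 68% interval for each parameter and flags parameters whose samples pile up against a prior edge, like the error-inflation term here. It assumes uniform priors given as `{name: (low, high)}`; in this notebook you would pass `result['samples']`, `PARAMS`, and the bounds used in `prior_set.line`. The 1% margin and 2% threshold are arbitrary illustrative choices, and the demo samples are synthetic.

```python
import numpy as np

def summarize_posterior(samples, names, bounds, edge_frac=0.02):
    """Print median and 68% interval per parameter; flag prior-edge pile-up.

    A parameter is flagged when more than `edge_frac` of its samples lie
    within 1% of the prior width of either edge, which suggests the
    posterior is truncated by the prior (e.g. only an upper limit).
    """
    samples = np.asarray(samples)
    flags = {}
    for j, name in enumerate(names):
        x = samples[:, j]
        lo, hi = bounds[name]
        p16, p50, p84 = np.percentile(x, [16, 50, 84])
        margin = 0.01 * (hi - lo)
        near_edge = max(np.mean(x < lo + margin), np.mean(x > hi - margin))
        flags[name] = bool(near_edge > edge_frac)
        note = "  <- piles up at a prior edge" if flags[name] else ""
        print(f"{name}: {p50:.4g} +{p84 - p50:.2g} -{p50 - p16:.2g}{note}")
    return flags

# Synthetic demo: a well-constrained slope and an error term stuck at its lower bound
rng = np.random.default_rng(0)
n = 5000
demo = np.column_stack([rng.normal(0, 1e-6, n), -10 + rng.exponential(0.5, n)])
flags = summarize_posterior(demo, ['m', 'log_err_inf'],
                            {'m': (-1e-5, 1e-5), 'log_err_inf': (-10, 3)})
```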

## Exercise to check your understanding 

1. Go back to param_set and prior_set and rerun the retrieval without the error inflation term. 

TIP: If you follow the formalism, you will not have to change the likelihood code. You should only have to edit param_set, prior_set, and model_set (which currently unpacks `log_err_inf` and builds the error inflation dictionary), plus guesses_set if you are using it to check your model.
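To decide whether the error inflation term is warranted, compare the Bayesian evidence of the two runs. UltraNest's result dictionary stores the natural-log evidence and its uncertainty under `result['logz']` and `result['logzerr']`; the hypothetical helper below turns two of them into a log Bayes factor and a verbal label. The thresholds follow a Jeffreys-style scale commonly quoted in astrophysics, so treat the labels as a rough guide rather than a strict rule.

```python
import numpy as np

def compare_evidence(logz_a, logz_b, logzerr_a=0.0, logzerr_b=0.0):
    """ln Bayes factor of model A over model B, with a rough uncertainty.

    Labels: |ln B| < 1 inconclusive, 1-2.5 weak, 2.5-5 moderate, > 5 strong.
    """
    ln_b = logz_a - logz_b
    err = np.hypot(logzerr_a, logzerr_b)   # errors added in quadrature
    mag = abs(ln_b)
    if mag < 1:
        label = 'inconclusive'
    elif mag < 2.5:
        label = 'weak'
    elif mag < 5:
        label = 'moderate'
    else:
        label = 'strong'
    favored = 'A' if ln_b > 0 else ('B' if ln_b < 0 else 'neither')
    return ln_b, err, label, favored

# e.g. compare_evidence(result_with['logz'], result_without['logz'],
#                       result_with['logzerr'], result_without['logzerr'])
# where result_with / result_without are your two (hypothetically named) runs
print(compare_evidence(10.0, 7.0))
```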

# Shortcut to get a grid fitting retrieval template in script form 

Now that you understand the basics of running a simple line model, let's introduce some PICASO tools to help you quickly set up a new retrieval script without the hassle of a notebook.

In [None]:
import picaso.retrieval as pr

In [None]:
rtype='line' #first let's specify the retrieval type 'line' (we will introduce the other options in future tutorials)
script_name='run_test.py' #specify a script name 
sampler_output = '/data/test/ultranest/line' #what folder do you want your ultranest output to go to? 
pr.create_template(rtype,script_name,sampler_output)

Open up `run_test.py` and modify what you need. We have marked key areas you might want to modify with "CHANGEME"

To run the script with MPI on 5 CPUs:

    >> mpiexec -n 5 python -m mpi4py run_test.py

## Further Reading

https://johannesbuchner.github.io/UltraNest/example-sine-bayesian-workflow.html