In [None]:
from floppity import Retrieval, helpers
import numpy as np
import matplotlib.pyplot as plt
plt.rc('font', size=16)
import torch
from corner import corner
import cloudpickle as pickle
from tqdm import trange
from floppity.simulators import read_ARCiS_input, ARCiS

#### Set up the retrieval

Let's first define the ARCiS input file we will be using, and read the observations and parameters to be fit from it. We will also pass the input file and the output directory to the ARCiS simulator (a Python wrapper that calls your ARCiS installation).

In [None]:
import os

# Root working directory for this target. Set the ARCIS_WORK_DIR environment
# variable to point elsewhere instead of editing a user-specific absolute path.
WORK_DIR = os.environ.get('ARCIS_WORK_DIR', '/Users/floppityflappity/Work/WISE1828')

# ARCiS input file; read_ARCiS_input extracts from it the parameters to fit
# (pars) and the observations (obs_list).
arcis_input = os.path.join(WORK_DIR, 'input_j1828_chem.dat')
pars, obs_list = read_ARCiS_input(arcis_input)

# Keyword arguments handed to the ARCiS simulator wrapper (simulator_kwargs in
# R.run): the input file and the output directory.
ARCiS_kwargs = dict(
    input_file=arcis_input,
    output_dir=os.path.join(WORK_DIR, 'test_231025'),
)

Let's now define the training and flow hyperparameters.

In [None]:
# Training hyperparameters (passed to R.run as training_kwargs).
training_kwargs = {
    'stop_after_epochs': 40,
    'num_atoms': 10,
    'learning_rate': 3e-4,
    'force_first_round_loss': True,
    'use_combined_loss': True,
}

# Normalizing-flow hyperparameters (passed to R.run as flow_kwargs).
flow_kwargs = {
    'flow': 'nsf',
    'bins': 4,
    'transforms': 8,
    'blocks': 2,
    'hidden': 32,
    'dropout': 0.01,
    # NOTE(review): max_num_epochs reads like a training option; confirm that
    # floppity consumes it from flow_kwargs rather than training_kwargs.
    'max_num_epochs': 500,
}

Now let's create our retrieval object and load the observations and priors into it.

In [None]:
# Build the retrieval around the ARCiS simulator wrapper.
R = Retrieval(ARCiS)
# Parameters to fit, as read from the ARCiS input file; each entry carries
# 'min'/'max' prior bounds (read later when building the corner-plot ranges).
R.parameters=pars
# Load the observations listed in the ARCiS input file into R.obs.
R.get_obs(obs_list)

Additional parameters can be added as shown below.
The `post_process=True` flag means that these parameters are not passed to the simulator; instead, the simulated spectra are modified a posteriori using one of the functions in `postprocessing.py`.

In [None]:
# Optional: uncomment to fit extra parameters that are applied to the simulated
# spectra in post-processing instead of being passed to ARCiS. The positional
# arguments are presumably (name, lower bound, upper bound) -- confirm in
# Retrieval.add_parameter.
# R.add_parameter('RV', -40,-30, post_process=True)
# R.add_parameter('vrot', 50,70, post_process=True)

#### Run retrieval

In [None]:
# Launch the retrieval from scratch (resume=False) over n_rounds rounds.
# NOTE(review): n_samples_init is presumably the number of simulations in the
# first round and n_samples the number per later round -- confirm in Retrieval.run.
R.run(
    flow_kwargs=flow_kwargs,
    training_kwargs=training_kwargs,
    simulator_kwargs=ARCiS_kwargs,
    resume=False,
    n_threads=4,
    n_rounds=5,
    n_samples_init=256,
    n_samples=64,
)

#### Plot loss

In [None]:
%matplotlib inline
plt.plot(R.inference._summary['best_validation_loss'], marker='o')
plt.xlabel('Round')
plt.ylabel('Validation Loss')
plt.show()

#### Plot posterior

Collect the prior bounds (min, max) of every fitted parameter; they are used below as the axis ranges of the corner plot.

In [None]:
# Prior bounds (min, max) for every parameter, in the iteration order of
# R.parameters; used as the per-axis ranges of the corner plot.
full = [
    (R.parameters[name]['min'], R.parameters[name]['max'])
    for name in R.parameters
]

In [None]:
# Styling options for the corner plot (presumably forwarded to corner.corner);
# `range=full` fixes each axis to the prior bounds.
CORNER_KWARGS = {
    'smooth': 0.6,
    'plot_density': True,
    'hist_bin_factor': 1,
    'plot_contours': True,
    'show_titles': True,
    'color': 'mediumpurple',
    'range': full,
}
# NOTE(review): the first argument -1 presumably selects the last round's
# posterior -- confirm in Retrieval.plot_corner.
fig = R.plot_corner(-1, **CORNER_KWARGS)

## Posterior diagnostics

### Posterior predictive check

In [None]:
%matplotlib inline
xs = R.post_x
# Xs=np.concatenate(list(xs.values()), axis=1)

plt.figure(figsize=(40,10))
for key in R.obs.keys():
    plt.errorbar(x=R.obs[key][:,0], y=R.obs[key][:,1], yerr=R.obs[key][:,2], c='r', lw=1, zorder=0)
for i in trange(len(xs['obs1'])):
    for key in R.obs.keys():
        plt.plot(R.obs[key][:,0], xs[key][i], c='b', alpha=0.01, zorder=1)
# plt.xlim(0.5,2)
# plt.ylim(-1e-6,2e-5)
plt.xlabel('Wavelength')
plt.ylabel('Flux')
plt.show()

### Check best fitting model from the training set

In [None]:
# Select the draw in xs that best matches the observations. bf_idx is used
# below as a row index into xs[key]. NOTE(review): bf_chi is presumably the
# corresponding chi-square value -- confirm in helpers.find_best_fit.
bf_idx, bf_chi=helpers.find_best_fit(R.obs, xs)

In [None]:

# Observed spectra (black, with error bars) overlaid with the best-fitting
# draw selected by helpers.find_best_fit (red).
plt.figure(figsize=(24,8))
for key in R.obs.keys():
    plt.errorbar(x=R.obs[key][:,0], y=R.obs[key][:,1], yerr=R.obs[key][:,2], c='k', lw=1, zorder=0)
    plt.plot(R.obs[key][:,0], xs[key][bf_idx], c='r', alpha=1, zorder=1)
# To zoom in, e.g.: plt.xlim(0.9, 2); plt.ylim(-1e-6, 2e-5)
plt.xlabel('Wavelength')
plt.ylabel('Flux')
plt.title('Best-fitting model')
plt.show()
