In [None]:
import os

# Pin the OpenMP / OpenBLAS / MKL thread pools to `nthreads`. This cell runs before
# numpy is imported (next cell), since these variables are typically only read when
# the numerical runtimes load.
nthreads = 1
os.environ.update({
    env_var: str(nthreads)
    for env_var in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS")
})

In [None]:
import matplotlib as mpl
import numpy as np
import matplotlib.pyplot as plt
import corner

In [None]:
# Render inline matplotlib figures at double resolution (sharper on high-DPI screens).
%config InlineBackend.figure_format = 'retina'

In [None]:
import inferagni as ia

In [None]:
# Grid data directory, relative to the notebook's working directory
# (assumes the notebook is run from a sibling directory of `src/` — TODO confirm).
datadir = "../src/inferagni/data/"
# Construct the inferagni grid from the files in `datadir`; `gr` is used by every later cell.
gr = ia.grid.Grid(datadir)

In [None]:
# Summarise the grid's input parameters (output format is defined by inferagni).
gr.show_inputs()

In [None]:
# Show the grid's sample points as returned by `Grid.get_points` (presumably the
# distinct values per input dimension — confirm against inferagni).
print(gr.get_points())

In [None]:
# Mass-radius diagram of the grid. `key1` / `key2` choose the two quantities that
# massrad_2d visualises (see `ia.plot.massrad_2d` for how each is drawn), and
# `controls` fixes the remaining grid inputs.
opts = {
    "key1"    : "instellation",
    "key2"    : "frac_atm",
    "controls": {"logZ":0.5, "logCO":-1,  "frac_core":0.325, "Teff":4450},
    "show_scatter" : True,
    "show_isolines" : True,
}

fig = ia.plot.massrad_2d(gr, **opts)
# `plt.show` takes no figure argument; it only accepts the keyword `block`. A positional
# figure raises TypeError on standard backends, and the inline backend silently binds it
# to its `close` flag. Call it without arguments.
plt.show()

In [None]:
def plot_interp_slice(grid, zkey, controls, log_z=False, zeng_core=0.325,
                      resolution=100, method="linear"):
    """Draw filled contours of `zkey`, interpolated over the mass-radius plane.

    Parameters
    ----------
    grid : ia.grid.Grid
        Model grid whose `interp_2d` provides the (mass, radius, zkey) surface.
    zkey : str
        Grid quantity shown as the contoured colour variable.
    controls : dict
        Fixed values for the remaining grid inputs, forwarded to `grid.interp_2d`
        and echoed as the plot title.
    log_z : bool, default False
        Use a logarithmic colour normalisation and a log-scaled colourbar.
    zeng_core : float, default 0.325
        Key into `ia.plot.zeng` selecting the Zeng+2019 reference curve to overlay.
    resolution, method :
        Forwarded to `grid.interp_2d`.

    Returns
    -------
    (fig, ax) : the matplotlib figure and axes that were drawn.
    """
    itp_x, itp_y, itp_z = grid.interp_2d(controls=controls, zkey=zkey,
                                         resolution=resolution, method=method)

    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    cs = ax.contourf(itp_x, itp_y, itp_z, levels=20, cmap="viridis",
                     norm="log" if log_z else None)

    # Reference curve: element 0 is plotted on the mass axis, element 1 on the radius axis.
    zeng_curve = ia.plot.zeng[zeng_core]
    ax.plot(zeng_curve[0], zeng_curve[1], c="k", ls="--", label="Zeng+2019")
    # The curve is labelled, so draw the legend (the original never did).
    ax.legend(fontsize=8)

    ax.set(xlabel=ia.util.varprops["mass_tot"].label,
           ylabel=ia.util.varprops["r_phot"].label, xlim=(0, 10))
    ax.set_title(f"{controls}", fontsize=9)

    cbar = fig.colorbar(cs, label=ia.util.varprops[zkey].label)
    if log_z:
        cbar.ax.set_yscale("log")

    # pyplot.show takes no figure argument (passing one is a TypeError on standard backends).
    plt.show()
    return fig, ax


fig, ax = plot_interp_slice(
    gr, "instellation",
    {"frac_atm":0.0063, "frac_core":0.325, "logZ":0, "Teff":4450, "logCO":-1},
    log_z=True,
)

fig, ax = plot_interp_slice(
    gr, "frac_atm",
    {"frac_core":0.325, "logZ":0.0, "instellation":10.0, "Teff":4450, "logCO":-1},
)

In [None]:
# Quantity examined by the interpolation tests below. `interp_init` presumably prepares
# the interpolator that `gr.interp_eval(..., vkey=vkey)` uses — confirm in inferagni.
vkey = "r_phot"
gr.interp_init(vkey=vkey)

### Test across masses with different interpolation methods

In [None]:

# Compare interpolation methods along a 1-D slice in total mass. Every other grid input
# is held at the values of grid point `idx`; that point's stored `vkey` value is shown
# as "Truth" for reference against the interpolated curves.
fig, ax = plt.subplots(1, 1)

idx = 100_000
n_grid = len(gr.data[vkey].values)
# Fail early with a clear message rather than an IndexError from the lookups below.
assert 0 <= idx < n_grid, f"idx={idx} is outside the grid ({n_grid} points)"
eval_loc = {k: gr.data[k].values[idx] for k in gr.input_keys}

val_tru = ia.util.undimen(gr.data[vkey].values[idx], vkey)
# zorder keeps the truth marker visible on top of the interpolated curves.
ax.scatter(eval_loc["mass_tot"], val_tru, s=50, c='r', label="Truth", edgecolors='k', zorder=20)

# The mass axis is the same for every method, so build it once.
mass_arr = np.linspace(0.5, 12, 100)
for method in ("nearest", "linear"):
    # Evaluate on a fresh dict per mass so `eval_loc` keeps the truth point's coordinates.
    val_est = [
        gr.interp_eval({**eval_loc, "mass_tot": mass}, method=method, vkey=vkey)
        for mass in mass_arr
    ]
    ax.plot(mass_arr, val_est, label=f"Interpolated ({method})", zorder=10)

ax.set(xlabel=ia.util.varprops["mass_tot"].label, ylabel=ia.util.varprops[vkey].label)
ax.set_title(f"Grid point idx={idx}", fontsize=9)
ax.legend()
ax.grid(zorder=-2, alpha=0.3)

plt.show()

### Test across masses, and other dimension

In [None]:

# Interpolate `vkey` along total mass for several values of a second input (`z_key`),
# holding all other inputs at grid point `idx`. Each curve is coloured by its `z_key`
# value; the grid point's stored value is shown as "Truth".
fig, ax = plt.subplots(1, 1)

idx = 1000
n_grid = len(gr.data[vkey].values)
# Fail early with a clear message rather than an IndexError from the lookups below.
assert 0 <= idx < n_grid, f"idx={idx} is outside the grid ({n_grid} points)"
eval_loc = {k: gr.data[k].values[idx] for k in gr.input_keys}

val_tru = ia.util.undimen(gr.data[vkey].values[idx], vkey)
ax.scatter(eval_loc["mass_tot"], val_tru, s=50, c='r', label="Truth", edgecolors='k', zorder=20)

# Note: np.arange excludes the stop value, so z runs from -3 to 2.5 in steps of 0.5.
z_key = "logZ"; z_arr = np.arange(-3, 3, 0.5)
norm = mpl.colors.Normalize(vmin=np.amin(z_arr), vmax=np.amax(z_arr), clip=True)
sm = mpl.cm.ScalarMappable(norm=norm, cmap='viridis')

# The mass axis is the same for every z, so build it once.
mass_arr = np.linspace(0.5, 12, 20)
for z in z_arr:
    # Evaluate on a fresh dict so `eval_loc` keeps the truth point's coordinates.
    val_est = [
        gr.interp_eval({**eval_loc, "mass_tot": mass, z_key: z}, vkey=vkey)
        for mass in mass_arr
    ]
    ax.plot(mass_arr, val_est, label=f"{z_key} = {z:g}", c=sm.to_rgba(z), zorder=10)

ax.set(xlabel=ia.util.varprops["mass_tot"].label, ylabel=ia.util.varprops[vkey].label)
ax.set_title(f"Grid point idx={idx}", fontsize=9)
ax.legend(loc='upper left', bbox_to_anchor=(1,1))
ax.grid(zorder=-2, alpha=0.3)

plt.show()

### Test retrieval

In [None]:
# Initialise interpolators for further grid quantities, in the same order as listed.
for _extra_vkey in ("t_surf", "mass_tot", "μ_phot"):
    gr.interp_init(vkey=_extra_vkey)
del _extra_vkey

In [None]:
# Observational constraints: quantity -> [value, uncertainty]
# (presumably a 1-sigma Gaussian width — TODO confirm in ia.retrieve).
obs = {
    "r_phot":   [6.0, 0.1],
    "mass_tot": [3.0, 0.2],
    "μ_phot":   [0.02, 0.005],
}

# Run the MCMC retrieval of the grid inputs against `obs`.
sampler = ia.retrieve.run(
    obs,
    gr,
    n_walkers=30,
    n_procs=1,
    n_steps=5000,
)

In [None]:
# Integrated autocorrelation time per input parameter.
# Assumes `sampler` is an emcee EnsembleSampler (TODO confirm in ia.retrieve). By default
# emcee raises AutocorrError when the chain is shorter than ~50*tau, which would abort
# "Run All". With quiet=True it only warns; in that case treat the estimate as unreliable.
tau = sampler.get_autocorr_time(quiet=True)
for key, tau_k in zip(gr.input_keys, tau):
    print(f"{key:16s}: {tau_k:8.1f}")

In [None]:
# 1. Extract the flattened samples, discarding the initial burn-in period
samples = sampler.get_chain(discard=500, flat=True, thin=15)
print(f"Thinned samples: {samples.size}")
print("")

for i, key in enumerate(gr.input_keys):
    mcmc = np.percentile(samples[:, i], [16, 50, 84])
    q = np.diff(mcmc)
    print(f"{key:16s}: {mcmc[1]:8g} (+{q[1]:8g} / -{q[0]:.8g})")

print(ia.util.print_sep_min)
print("")

In [None]:
# For each posterior sample, evaluate the observed quantities that are not grid inputs
# (e.g. r_phot, μ_phot) so that they can be plotted alongside the inputs.
output_keys = [k for k in obs if k not in gr.input_keys]
print(samples.shape)
output_samples = np.array([
    # Sample columns follow gr.input_keys order. Pass a {input_key: value} dict, the
    # same call convention used for interp_eval earlier in the notebook.
    [gr.interp_eval(dict(zip(gr.input_keys, sam)), vkey=k) for k in output_keys]
    for sam in samples
])

In [None]:
# Sanity check: one row per posterior sample, one column per derived (non-input) quantity.
print(output_samples.shape)

In [None]:
# Join input-parameter samples and derived quantities column-wise. Labels list the
# inputs first, then the observed keys that are not grid inputs.
all_samples = np.hstack([samples, output_samples])
all_labels = [*gr.input_keys, *(key for key in obs if key not in gr.input_keys)]
print(all_labels)

In [None]:

print(f"Plotting {samples.shape[0]} samples...")

# Corner plot of the input parameters together with the derived quantities.
# The 16/50/84 % quantiles mark the median and a 1-sigma-equivalent interval.
fig = corner.corner(
    all_samples,
    labels=all_labels,
    quantiles=[0.16, 0.5, 0.84],
    show_titles=True,
    title_kwargs={"fontsize": 9},
    color="tab:blue",
)

plt.show()
print(ia.util.print_sep_min)
print("")

In [None]:


# Diagnostic trace plot: check whether the walkers have converged or are still wandering.
# `samples` above is the flattened 2-D array, so indexing it as [:, :, i] fails. Fetch the
# unflattened chain instead, shape (n_steps, n_walkers, n_params) in emcee's convention;
# it includes the burn-in steps.
chain = sampler.get_chain()
n_params = len(gr.input_keys)
# squeeze=False keeps `axes` 2-D even when the grid has a single input parameter.
fig, axes = plt.subplots(n_params, 1, figsize=(10, 7), sharex=True, squeeze=False)

for i, k in enumerate(gr.input_keys):
    ax = axes[i, 0]
    ax.plot(chain[:, :, i], "k", alpha=0.3)
    ax.set_xlim(0, len(chain))
    ax.set_ylabel(ia.util.varprops[k].label, fontsize=9)
    ax.yaxis.set_label_coords(-0.1, 0.5)

axes[-1, 0].set_xlabel("Step Number")
plt.tight_layout()
plt.show()
print(ia.util.print_sep_min)
print("")