In [None]:
import os

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import colors
import matplotlib.cm as cm

import starry
import astropy.units as u

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import *

import arviz as az
import seaborn as sns
import pandas as pd
import pickle as pkl
from scipy.linalg import svd

from jwst_eclipse_mapping.utils import *

# Seed NumPy's global RNG: later cells draw posterior-sample indices with np.random.randint.
np.random.seed(42)
# Run starry in non-lazy mode and silence its messages.
starry.config.lazy = False
starry.config.quiet = True

# 64-bit floats in JAX/NumPyro, and 4 host devices
# (NOTE(review): presumably for parallel MCMC chains; not used in the cells shown).
numpyro.enable_x64(True)
numpyro.set_host_device_count(4)

In [None]:
# Inline figures, then run the shared setup script.
# NOTE(review): notebook_setup.py is not shown here; it may define names used below — confirm.
%matplotlib inline
%run notebook_setup.py

In [None]:
def sim_snapshot_to_sh(path, wav_map, ydeg=20, temp_offset=-450):
    """Expand a simulated temperature snapshot into a spectral starry map.

    The snapshot is read with ``np.loadtxt`` and reshaped to a 64 x 128
    (latitude x longitude) grid. The blackbody intensity ``pi * planck(T, wav)``
    is fitted with spherical harmonics at 5 wavelengths spread over
    ``wav_map``. The resulting coefficients are then linearly interpolated to
    every wavelength in ``wav_map``.

    Parameters
    ----------
    path : str or path-like
        Text file holding ``64 * 128`` temperature values.
    wav_map : array_like
        Wavelength grid of the returned map (passed to ``planck``).
    ydeg : int, optional
        Spherical-harmonic degree, used for both the fit and the returned map.
    temp_offset : float, optional
        Added to every grid temperature before computing intensities.

    Returns
    -------
    starry.Map
        Degree-``ydeg`` map with ``nw=len(wav_map)`` wavelength bins.
    """
    nlat = 64
    nlon = 128
    wav_map = np.asarray(wav_map)

    # One reshape (and copy) instead of re-reshaping inside a double loop.
    temp_grid = np.loadtxt(path).reshape((nlat, nlon)).astype(float)
    # Roll by half the longitude grid (~180 deg) and apply the offset.
    # NOTE(review): presumably re-centres the substellar point — confirm.
    temp_grid = np.roll(temp_grid, nlon // 2, axis=1) + temp_offset

    # Fitting every wavelength is slow: fit 5 of them and interpolate the rest.
    idcs = np.linspace(0, len(wav_map) - 1, 5).astype(int)
    map_tmp = starry.Map(ydeg)
    x_list = []
    for i in idcs:
        I_grid = np.pi * planck(temp_grid, wav_map[i])
        map_tmp.load(I_grid, force_psd=True)
        x_list.append(map_tmp._y * map_tmp.amp)

    # Rows: SH coefficients; columns: the fitted wavelengths.
    x_sparse = np.vstack(x_list).T
    x = np.vstack([np.interp(wav_map, wav_map[idcs], row) for row in x_sparse])

    # Build the output at the fitted degree so the coefficient counts agree
    # (it must not be silently reset to a fixed degree).
    map_planet = starry.Map(ydeg=ydeg, nw=len(wav_map))
    map_planet[1:, :, :] = x[1:] / x[0]
    map_planet.amp = x[0]

    return map_planet

def inferred_map_intensity_to_physical_units(I_raw):
    """Rescale an intensity evaluated on an inferred map to physical units.

    Reads the notebook globals ``params_s``, ``params_p`` and ``filt``.

    Parameters
    ----------
    I_raw : array_like
        Intensity from an inferred map (e.g. ``Map.intensity`` output).

    Returns
    -------
    numpy.ndarray
        ``I_raw`` times ``pi`` times the band-integrated Planck function of the
        stellar temperature, times ``(r_star / r_planet) ** 2``.
    """
    star_band_intensity = np.pi * integrate_planck_over_filter(params_s["T"].value, filt)
    radius_ratio = params_s["r"] / params_p["r"]
    return np.array(I_raw * star_band_intensity * radius_ratio ** 2)


def rescale_simulated_map_units(map_spectral, wav_map, filt):
    """Integrate a spectral starry map over a filter bandpass.

    Parameters
    ----------
    map_spectral : starry.Map
        Spectral map with one coefficient column per entry of ``wav_map``.
    wav_map : array_like
        Wavelengths of the map's spectral bins.
    filt : sequence of two arrays
        ``(wavelength, throughput)``; the throughput is interpolated onto ``wav_map``.

    Returns
    -------
    starry.Map
        Single-band map of the same degree as ``map_spectral``. It holds the
        throughput-weighted, wavelength-integrated coefficients, with the
        amplitude scaled by ``1e-06``.
    """
    # np.trapz was renamed np.trapezoid in NumPy 2.0 (old name deprecated).
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    thr = np.interp(wav_map, filt[0], filt[1])
    x_int = trapezoid(
        map_spectral.amp * map_spectral._y * thr, x=wav_map, axis=1
    )

    # Use the input's degree (not a fixed 20) so len(x_int) always fits the map.
    map_int = starry.Map(map_spectral.ydeg)
    map_int[1:, :] = x_int[1:] / x_int[0]
    # NOTE(review): 1e-06 looks like a unit conversion of the amplitude — confirm its origin.
    map_int.amp = x_int[0] * 1e-06
    return map_int

def plot_intensity_slice_simulated_map(ax, map_planet_int, lat=None, lon=None, color='C0'):
    """Draw a 1-D intensity cut through a simulated map onto ``ax``.

    The cut coordinate always spans [-90, 90] deg in 200 points:
    - with ``lat`` given, the cut runs along longitude at that latitude;
    - otherwise it runs along latitude at longitude ``lon``.
    The plotted intensity is divided by pi.
    """
    coords = np.linspace(-90, 90, 200)

    if lat is not None:
        intensity = map_planet_int.intensity(lat=lat, lon=coords) / np.pi
    else:
        intensity = map_planet_int.intensity(lat=coords, lon=lon) / np.pi

    ax.plot(coords, intensity, f'{color}-', lw=2.)
    ax.set(xticks=np.arange(-90, 135, 45))

def plot_intensity_slice_inferred_map(ax, samples, lat=None, lon=None, color='C0',
                                      nsamples=100, ydeg=None):
    """Plot the posterior spread of a 1-D intensity cut through the inferred maps.

    The function:
    1. draws ``nsamples`` random coefficient vectors from ``samples['x']``;
    2. evaluates each one along the cut;
    3. converts it with ``inferred_map_intensity_to_physical_units``;
    4. plots the 16-84 percentile band plus the median.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes.
    samples : mapping
        Posterior samples; each row of ``samples['x']`` is one coefficient
        vector whose element 0 is the amplitude.
    lat, lon : float, optional
        With ``lat`` given, the cut runs along longitude at that latitude.
        Otherwise it runs along latitude at longitude ``lon``. The cut
        coordinate spans [-90, 90] deg.
    color : str, optional
        Matplotlib color for the band and the median line.
    nsamples : int, optional
        Number of posterior draws (with replacement).
    ydeg : int, optional
        Map degree. Defaults to the degree implied by the coefficient count,
        ``sqrt(len(x)) - 1``, instead of relying on a notebook global.
    """
    coords = np.linspace(-90, 90, 200)
    x_samples = np.asarray(samples['x'])
    if ydeg is None:
        ydeg = int(round(np.sqrt(x_samples.shape[-1]))) - 1

    # Draw all indices at once (one index per evaluated sample).
    idcs = np.random.randint(0, len(x_samples), nsamples)

    map_inf = starry.Map(ydeg)
    I_list = []
    for x in x_samples[idcs]:
        map_inf[1:, :] = x[1:] / x[0]
        map_inf.amp = x[0]
        if lat is None:
            I = map_inf.intensity(lat=coords, lon=lon)
        else:
            I = map_inf.intensity(lat=lat, lon=coords)
        I_list.append(inferred_map_intensity_to_physical_units(I))

    perc_16, perc_50, perc_84 = np.percentile(I_list, [16, 50, 84], axis=0)

    ax.fill_between(coords, perc_16, perc_84, alpha=0.4, color=color)
    ax.plot(coords, perc_50, lw=2., color=color)
    ax.set(xticks=np.arange(-90, 135, 45))

    
def compute_preimage_operator(A):
    """Return the orthogonal projector onto the row space of ``A``.

    For a coefficient vector ``x``, ``R @ x`` keeps only the component of
    ``x`` that ``A`` is sensitive to, so ``A @ (R @ x) == A @ x`` up to
    round-off. The complementary null-space projector is ``I - R``.

    Parameters
    ----------
    A : array_like, shape (m, n)
        Design matrix.

    Returns
    -------
    R : ndarray, shape (n, n)
        Symmetric, idempotent projector whose trace equals the rank of ``A``.
    """
    A = np.asarray(A)
    n = A.shape[1]
    if A.size == 0:
        # Empty matrix: rank 0, nothing is constrained.
        return np.zeros((n, n))

    # One thin SVD gives both the rank and the row-space basis
    # (np.linalg.matrix_rank would run a second SVD of A).
    _, s, VT = svd(A, full_matrices=False)
    # Same default tolerance as np.linalg.matrix_rank.
    tol = s.max() * max(A.shape) * np.finfo(s.dtype).eps
    rank = int(np.count_nonzero(s > tol))

    V_row = VT[:rank]
    return V_row.T @ V_row

def get_preimage_map(map, R, ydeg):
    """Project a spectral map with ``R`` and truncate it to degree ``ydeg``.

    Parameters
    ----------
    map : starry.Map
        Map with ``nw`` wavelength bins. Its amplitude-weighted coefficient
        count must match the size of ``R``.
    R : ndarray
        Square projector, e.g. from ``compute_preimage_operator``.
    ydeg : int
        Degree of the returned map; higher-degree coefficients are dropped
        after projecting.

    Returns
    -------
    starry.Map
        New degree-``ydeg`` map with the same ``nw`` as ``map``.
    """
    projected = R @ (map._y * map.amp)
    n_coeffs = (ydeg + 1) ** 2
    projected = projected[:n_coeffs]

    preimage = starry.Map(ydeg, nw=map.nw)
    preimage[1:, :, :] = projected[1:, :] / projected[0]
    preimage.amp = projected[0]
    return preimage

def compute_design_matrix(t, params_p, params_s, texp, ydeg_inf):
    """Build the starry design matrix of the star-planet system at times ``t``.

    Parameters
    ----------
    t : array_like
        Observation times in days (the system's ``time_unit``).
    params_p, params_s : dict-like
        Planet and star parameters. Most entries are astropy quantities whose
        ``.value`` is used; ``params_p["ecc"]`` and ``params_s["u"]`` are used as is.
    texp : astropy.units.Quantity
        Exposure time; converted to days.
    ydeg_inf : int
        Spherical-harmonic degree of the planet map (previously ignored in
        favour of a hard-coded 20).

    Returns
    -------
    A : array
        Columns of ``A_full`` for the planet map (star columns dropped).
    A_full : array
        Full design matrix from ``starry.System.design_matrix``.
    """
    # Star: degree-1 map with quadratic limb darkening (u1, u2).
    star_ydeg = 1
    star = starry.Primary(
        starry.Map(ydeg=star_ydeg, udeg=2),
        r=params_s["r"].value,
        m=params_s["m"].value,
        length_unit=u.Rsun,
        mass_unit=u.Msun,
    )
    star.map[1] = params_s["u"][0]
    star.map[2] = params_s["u"][1]

    planet = starry.Secondary(
        starry.Map(ydeg=ydeg_inf, inc=params_p["inc"].value),
        ecc=params_p["ecc"],
        omega=params_p["omega"].value,
        r=params_p["r"].value,
        porb=params_p["porb"].value,
        prot=params_p["prot"].value,
        t0=params_p["t0"].value,
        inc=params_p["inc"].value,
        theta0=180,
        length_unit=u.Rsun,
        angle_unit=u.deg,
        time_unit=u.d,
    )
    sys_fit = starry.System(star, planet, texp=texp.to(u.d).value)

    A_full = sys_fit.design_matrix(t)
    # The leading (star_ydeg + 1)**2 columns belong to the star's map.
    A = A_full[:, (star_ydeg + 1) ** 2:]

    return A, A_full

In [None]:
# Load data for one inference run. The directory name encodes the filter and
# the inferred map degree; `filter_name` and `ydeg_inf` must agree with it.
# Alternative run: "../data/output/hd189_f322w2_ydeg_7_snr_23.8_texp_3.07/"
path_to_dir = "../data/output/hd189_f444w_ydeg_7_snr_20.8_texp_2.04/"

filter_name = "f444w"
ydeg_inf = 7

# Catch switching the path without updating the settings (or vice versa).
assert f"_{filter_name}_" in path_to_dir, (
    f"filter_name={filter_name!r} does not match path_to_dir={path_to_dir!r}"
)
assert f"_ydeg_{ydeg_inf}_" in path_to_dir, (
    f"ydeg_inf={ydeg_inf} does not match path_to_dir={path_to_dir!r}"
)

# NOTE: allow_pickle=True unpickles these files, which can execute arbitrary
# code — only load trusted, locally generated outputs.
params_p = np.load('../data/system_parameters/hd189_orbital_params_planet.p', allow_pickle=True)
params_s = np.load('../data/system_parameters/hd189_orbital_params_star.p', allow_pickle=True)
obs_dict_list = np.load(os.path.join(path_to_dir, 'obs_dict_list.pkl'), allow_pickle=True)
samples_list = np.load(os.path.join(path_to_dir, 'samples_list.pkl'), allow_pickle=True)

obs1, obs2, obs3, obs4 = obs_dict_list
t = obs1['t']
fsim_list = [obs['fsim'] for obs in (obs1, obs2, obs3, obs4)]
fsim_unif_list = [obs['fsim_unif'] for obs in (obs1, obs2, obs3, obs4)]

# Load filter: filt[0] wavelengths, filt[1] throughput; keep where throughput > 0.002
filt = load_filter(name=filter_name)
mask = filt[1] > 0.002

# Wavelength grid for starry map (should match filter range)
wav_map = np.linspace(filt[0][mask][0], filt[0][mask][-1], 80)

In [None]:
# Degree of the simulated maps; the design matrix (and hence the projector R)
# must be built at that same degree.
ydeg_sim = 20
A, A_full = compute_design_matrix(
    t, params_p, params_s, texp=2.04 * u.s, ydeg_inf=ydeg_sim
)
R = compute_preimage_operator(A)

In [None]:
# Load simulation snapshots (spherical-harmonic coefficients plus wavelength
# grid) at t = 25, 100, 300 and 500 days.
sim1, sim2, sim3, sim4 = [
    np.load(f"../data/hydro_snapshots_raw/T42_temp_0.1bar_{days}days_ylm.npz")
    for days in (25, 100, 300, 500)
]

def initialize_map(ydeg, nw, x):
    """Create a spectral starry map from amplitude-weighted coefficients ``x``.

    ``x`` has one row per spherical-harmonic coefficient and ``nw`` columns.
    Row 0 becomes the amplitude; the remaining rows are divided by it.
    """
    new_map = starry.Map(ydeg, nw=nw)
    amplitude = x[0]
    new_map[1:, :, :] = x[1:, :] / amplitude
    new_map.amp = amplitude
    return new_map

# Spectral SH maps of the four snapshots
maps_sim = [
    initialize_map(ydeg_sim, len(sim['wav_grid']), sim['x'])
    for sim in (sim1, sim2, sim3, sim4)
]
map1_sim, map2_sim, map3_sim, map4_sim = maps_sim

# Convert to temperature
maps_sim_temp = [
    spectral_radiance_map_to_bbtemp_map(sim_map, sim['wav_grid'], resol=200)
    for sim_map, sim in zip(maps_sim, (sim1, sim2, sim3, sim4))
]
map1_sim_temp, map2_sim_temp, map3_sim_temp, map4_sim_temp = maps_sim_temp

# Convert multi-spectral SH maps into SH maps integrated over the bandpass
maps_sim_integrated = [
    rescale_simulated_map_units(sim_map, sim['wav_grid'], filt)
    for sim_map, sim in zip(maps_sim, (sim1, sim2, sim3, sim4))
]
map_planet1_int, map_planet2_int, map_planet3_int, map_planet4_int = maps_sim_integrated

In [None]:
# Compute mean inferred maps from the posterior samples of each snapshot
mean_map1, mean_map2, mean_map3, mean_map4 = [
    get_mean_map(ydeg_inf, samples_list[k]['x'], projection=None, nsamples=200, resol=200)
    for k in range(4)
]

# Convert those maps to temperature
inferred_temp_maps = [
    inferred_intensity_to_bbtemp(mean_map, filt, params_s, params_p)
    for mean_map in (mean_map1, mean_map2, mean_map3, mean_map4)
]
mean_temp_map1, mean_temp_map2, mean_temp_map3, mean_temp_map4 = inferred_temp_maps

In [None]:
# Preimage temperature maps: project each simulated map with R, truncate to
# degree 10, then convert to blackbody temperature.
preimage_ydeg = 10
map1_preimage, map2_preimage, map3_preimage, map4_preimage = [
    get_preimage_map(sim_map, R, preimage_ydeg)
    for sim_map in (map1_sim, map2_sim, map3_sim, map4_sim)
]

map_preimage_temp_list = [
    spectral_radiance_map_to_bbtemp_map(preimage, sim['wav_grid'], resol=200)
    for preimage, sim in zip(
        (map1_preimage, map2_preimage, map3_preimage, map4_preimage),
        (sim1, sim2, sim3, sim4),
    )
]
map1_preimage_temp, map2_preimage_temp, map3_preimage_temp, map4_preimage_temp = map_preimage_temp_list

In [None]:
# NOTE(review): this cell is disabled and would fail if uncommented as-is:
# `ax_lcs` and `ax_res` are not created in any earlier cell of this notebook.
# # Plot simulated light curves
# for i, (fsim, fsim_unif) in enumerate(zip(fsim_list, fsim_unif_list)):
#     res = fsim - fsim_unif
#     ax_lcs[i].plot(t, (fsim_unif - 1)*1e06, 'k', label='uniform map')
#     ax_lcs[i].plot(t, (fsim - 1)*1e06, 'C1--', label='simulation snapshot')
#     ax_res[i].plot(t, res*1e06,'C1-')
    
#     for val in (-750,  -500,  -250, 0):
#         ax_lcs[i].axhline(val, color='grey', lw=0.5, zorder=-100)

#     for val in (-30,  0,  30):
#         ax_res[i].axhline(val, color='grey', lw=0.5, zorder=-100)
        
#     xl = 0.15
#     ax_lcs[i].set(xlim=(-xl, xl))
#     ax_res[i].set(xlim=(-xl, xl), ylim=(-65, 65))
    
# ax_lcs[0].set_yticks([-750, -500, -250, 0])
# ax_res[0].set_yticks([-30, 0, 30])

# for a in (ax_lcs[0], ax_res[0]):
#     a.spines["left"].set_visible(False)
#     a.spines["top"].set_visible(False)
#     a.spines["right"].set_visible(False)
#     a.spines["bottom"].set_visible(False)
#     a.set_xticks([])
#     a.tick_params(axis='y', colors='grey', length=0, labelsize=9)

# for a in ax_lcs[1:]:
#     a.axis('off')

# for a in ax_res[1:]:
#     a.axis('off')

# # for a in ax_lcs[1:]:
# #     a.axis('off')
# # for a in ax_res[1:]:
# #     a.axis('off')
# #     ax_res[i].axis('off')

In [None]:
# Comparison figure. Each snapshot gets a 4-column group of the grid, with a
# 1-column spacer between groups. Rows 0/2/4 hold the simulated, preimage and
# inferred maps; row 5 holds four posterior-sample thumbnails per snapshot.
fig = plt.figure(figsize=(10, 9))

gs = fig.add_gridspec(
    nrows=6,
    ncols=4 + 1 + 4 + 1 + 4 + 1 + 4,
    height_ratios=[4, 2, 4, 2, 4, 2.],
    hspace=0., wspace=0.
)

snapshot_days = (25, 100, 300, 500)
# First grid column of each snapshot's 4-column group
group_starts = (0, 5, 10, 15)

# Axes for simulated maps, their preimages and the inferred maps
ax_maps_sim = [fig.add_subplot(gs[0, c:c + 4]) for c in group_starts]
ax_maps_preim = [fig.add_subplot(gs[2, c:c + 4]) for c in group_starts]
ax_maps = [fig.add_subplot(gs[4, c:c + 4]) for c in group_starts]

# Axes for posterior-sample thumbnails, one grid cell each
ax_s1, ax_s2, ax_s3, ax_s4 = [
    [fig.add_subplot(gs[5, c + i]) for i in range(4)] for c in group_starts
]

norm = colors.Normalize(vmin=850, vmax=1150)
show_kwargs = dict(cmap="OrRd", norm=norm, colorbar=False)

# Plot simulated maps
map_display_sim = starry.Map(20)
for a, image in zip(ax_maps_sim, maps_sim_temp):
    map_display_sim.show(image=image, ax=a, **show_kwargs)

# Plot preimages and inferred maps
map_display = starry.Map(12)
for a, image in zip(ax_maps_preim, map_preimage_temp_list):
    map_display.show(image=image, ax=a, **show_kwargs)
for a, image in zip(ax_maps, inferred_temp_maps):
    map_display.show(image=image, ax=a, **show_kwargs)

# Mini plots with random posterior samples
map_sample = starry.Map(ydeg_inf)
thumb_res = 100
for ax_row, samples in zip([ax_s1, ax_s2, ax_s3, ax_s4], samples_list):
    x_samples = samples['x']
    # Draw from this snapshot's full posterior rather than an assumed 500 samples
    # (which raised IndexError for smaller, and ignored the tail of larger, posteriors).
    sample_idcs = np.random.randint(0, len(x_samples), len(ax_row))
    for a, idx in zip(ax_row, sample_idcs):
        x = x_samples[idx]
        map_sample[1:, :] = x[1:] / x[0]
        map_sample.amp = x[0]
        temp = inferred_intensity_to_bbtemp(
            map_sample.render(res=thumb_res), filt, params_s, params_p
        )
        map_sample.show(ax=a, image=temp, cmap="OrRd", norm=norm, grid=False)

# Colorbar
cbaxes = fig.add_axes([0.94, 0.33, 0.018, 0.4])
cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap="OrRd"), cax=cbaxes)
cb.set_label(label=r"blackbody temperature [K]", fontsize=10)
cb.ax.tick_params(labelsize=10)

# Row headings
for y, heading in (
    (0.96, 'Simulation snapshots'),
    (0.67, 'Simulation snapshot preimages'),
    (0.4, 'Inferred maps'),
):
    fig.text(0.515, y, heading, ha='center', va='center', fontweight="bold", fontsize=15)

for a, days in zip(ax_maps_sim, snapshot_days):
    a.set_title(rf"$t={days}$ days")

fig.savefig(
    os.path.join(path_to_dir, "maps_comparison.png"), bbox_inches="tight", dpi=300
)

In [None]:
# Intensity along longitude at fixed latitudes. Each panel shows the simulated
# map (black) and the inferred posterior band/median (colored by row).
slice_lats = [30, 0, -30]
snapshot_days = [25, 100, 300, 500]

fig, ax = plt.subplots(3, 4, figsize=(10, 6), sharey=True)

for row, lat in enumerate(slice_lats):
    for panel, map_int, samples in zip(ax[row, :], maps_sim_integrated, samples_list):
        plot_intensity_slice_simulated_map(panel, map_int, lat=lat, color='k')
        plot_intensity_slice_inferred_map(panel, samples, lat=lat, color=f'C{row}')

for panel in ax[:2, :].ravel():
    panel.set_xticklabels([])

for panel, days in zip(ax[0, :], snapshot_days):
    panel.set_title(f'$t={days}$ days')

for panel in ax.ravel():
    panel.grid(alpha=0.5)

for panel, lat in zip(ax[:, 0], slice_lats):
    panel.set(yticklabels=[], ylabel=f'Intensity\n({lat} deg lat.)')

fig.text(0.5, 0.02, 'Longitude [deg]', ha='center', va='center')

fig.savefig(
    os.path.join(path_to_dir, "intensity_longitudinal_slices.png"), bbox_inches="tight", dpi=300
)

In [None]:
# Intensity along latitude at fixed longitudes. Each panel shows the simulated
# map (black) and the inferred posterior band/median (colored by row).
slice_lons = [30, 0, -30]
snapshot_days = [25, 100, 300, 500]

fig, ax = plt.subplots(3, 4, figsize=(10, 6), sharey=True)

for row, lon in enumerate(slice_lons):
    for panel, map_int, samples in zip(ax[row, :], maps_sim_integrated, samples_list):
        plot_intensity_slice_simulated_map(panel, map_int, lon=lon, color='k')
        plot_intensity_slice_inferred_map(panel, samples, lon=lon, color=f'C{row}')

for panel in ax[:2, :].ravel():
    panel.set_xticklabels([])

for panel, days in zip(ax[0, :], snapshot_days):
    panel.set_title(f'$t={days}$ days')

for panel in ax.ravel():
    panel.grid(alpha=0.5)

for panel, lon in zip(ax[:, 0], slice_lons):
    panel.set(yticklabels=[], ylabel=f'Intensity\n({lon} deg long.)')

fig.text(0.5, 0.02, 'Latitude [deg]', ha='center', va='center')

fig.savefig(
    os.path.join(path_to_dir, "intensity_latitudinal_slices.png"), bbox_inches="tight", dpi=300
)

In [None]:
# NOTE(review): this power-spectrum cell is disabled. Before re-enabling:
#   * `def powerspectrum_comparison_plot(...)` ends with `;` instead of `:` (syntax error);
#   * `range(ydeg)` / `np.arange(ydeg_inf)` stop at l = ydeg - 1, so the
#     highest degree is skipped — presumably unintended, confirm;
#   * '$\log_{10} S(l)$' should be a raw string (`\l` is an invalid escape).
# def compute_ps(map):
#     ydeg = map.ydeg
#     ps = np.array([np.sum(map[l, :]**2)/(1 + 2*l) for l in range(ydeg)])
#     return ps

# def compute_ps_from_samples(ydeg, samples):
#     map = starry.Map(ydeg)
#     ps_list = []
        
#     for i in np.random.randint(0, len(samples['x']), 300):
#         x = samples['x'][i]
#         map[1:, :] = x[1:]/x[0]
#         map.amp = x[0]

#         ps = np.array([np.sum(map[l, :]**2)/(1 + 2*l) for l in range(ydeg)])
#         ps_list.append(ps)

#     ps_ = np.stack(ps_list)
        
#     return np.mean(ps_, axis=0), np.std(ps_, axis=0)

# def powerspectrum_comparison_plot(maps_sim_integrated, samples_list);
#     ps1_sim = compute_ps(maps_sim_integrated[0])
#     ps2_sim = compute_ps(maps_sim_integrated[1])
#     ps3_sim = compute_ps(maps_sim_integrated[2])
#     ps4_sim = compute_ps(maps_sim_integrated[3])


#     ps1_mean, ps1_std = compute_ps_from_samples(ydeg_inf, samples_list[0])
#     ps2_mean, ps2_std = compute_ps_from_samples(ydeg_inf, samples_list[1])
#     ps3_mean, ps3_std = compute_ps_from_samples(ydeg_inf, samples_list[2])
#     ps4_mean, ps4_std = compute_ps_from_samples(ydeg_inf, samples_list[3])
    
#     ls_inf = np.arange(ydeg_inf)

#     fig, ax = plt.subplots(1,2, figsize=(8,4), sharey=True)
#     ax[0].plot(ls_inf, ps1_sim[:len(ls_inf)], 'C0o-', alpha=0.6)
#     ax[0].plot(ls_inf, ps2_sim[:len(ls_inf)], 'C1o-', alpha=0.6)
#     ax[0].plot(ls_inf, ps3_sim[:len(ls_inf)], 'C2o-', alpha=0.6)
#     ax[0].plot(ls_inf, ps4_sim[:len(ls_inf)], 'C3o-', alpha=0.6)

#     ax[1].errorbar(ls_inf, ps1_mean, ps1_std, marker='o', color='C0', alpha=0.6, label='$t=25$ d')
#     ax[1].errorbar(ls_inf, ps2_mean, ps2_std, marker='o', color='C1', alpha=0.6, label='$t=100$ d')
#     ax[1].errorbar(ls_inf, ps3_mean, ps3_std, marker='o', color='C2', alpha=0.6, label='$t=300$ d')
#     ax[1].errorbar(ls_inf, ps4_mean, ps4_std, marker='o', color='C3', alpha=0.6, label='$t=500$ d')

#     for a in ax:
#         a.grid(alpha=0.5)
#         a.set(yscale='log', xlabel='$l$', xticks=ls_inf)
#     ax[0].set(ylabel='$\log_{10} S(l)$')
#     ax[0].set_title('Simulated power spectrum')

#     ax[1].set_title('Inferred power spectrum')
#     ax[1].legend()