In [None]:
!pip install seaborn

In [None]:
%load_ext autoreload
%autoreload 2


import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns
import pandas as pd
from pathlib import Path
from fnope.utils.misc import  get_data_dir, get_output_dir, get_project_root
from plotting_utils import colors,method_names




# Run-wide handles: compute device, project directories and per-method plot colours.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
data_dir, output_dir, root_dir = get_data_dir(), get_output_dir(), get_project_root()

# Colours are looked up by each method's position in plotting_utils.method_names.
simformer_idx = method_names.index("simformer")
simformer_color = colors[simformer_idx]
fno_idx = method_names.index("FNOPE")
fnope_color = colors[fno_idx]


In [None]:
# Experiment grid. The position of a method in `methods` is its row in the metric
# tables built below; the position of a budget in `nsims` is its column.
methods = [
    "FNOPE_always_equispaced_False",
    "Simformer",
]
nsims = [1_000, 10_000, 100_000]  # number of training simulations per run
eval_points = [20, 40]  # conditioning on 20 or 40 points
# Whether the predictive simulations (for both methods) start from Simformer's
# initial-condition estimate ("True") or from the prior ("False").
init_conditions = ["True", "False"]

out_dir = get_output_dir()
experiment_folder = Path(out_dir / "sir_experiment" / "FNOPE")





In [None]:
# Per-run evaluation summary. pd.read_pickle unpickles arbitrary objects, so only
# point it at files produced by this project.
summary_path = experiment_folder / "summary.pkl"
df = pd.read_pickle(summary_path)

In [None]:
# Which observation to show, and which of the four (n_cond x init) settings the
# figure is labelled with.
obs_index = 56  # index of the test observation (out of 100) to plot
n_cond = 40  # number of conditioning time points: 20 or 40
use_simformer_init = True  # True or False

# NOTE(review): `results` is NOT filtered by n_cond / use_simformer_init, so every
# row of the summary frame feeds the aggregated metrics. A filtered subset would be
#   df[(df["eval_num"] == n_cond) & (df["simformer_initial_condition"] == str(use_simformer_init))]
results = df
results


In [None]:
# Spot check: the SBC entry stored in the fourth row of `results`.
sbcs = results["sbcs"].to_numpy()[3]
sbcs

In [None]:


# Pool the per-run metrics of every (method, nsim) pair and summarise each pool as
# mean and standard error. Row ii of every table corresponds to methods[ii], column
# kk to nsims[kk].


def _pool_runs(run_values, flatten):
    """Collect the per-run entries of one metric column into a single list.

    flatten=True concatenates each run's entry (which must be iterable) into the
    pool; flatten=False appends each run's entry as one item.
    """
    pooled = []
    for value in run_values:
        if flatten:
            pooled.extend(value)
        else:
            pooled.append(value)
    return pooled


def _mean_and_se(pooled):
    """Return (mean, standard error) of a pooled metric list.

    The SE is np.std (population std, ddof=0) over all pooled values divided by
    sqrt(len(pooled)), i.e. by the number of list entries.
    """
    values = np.array(pooled)
    return np.mean(values), np.std(values) / np.sqrt(len(pooled))


sbc_results, tarp_results, predictive_mse_results = {}, {}, {}
table_shape = (len(methods), len(nsims))
sbc_mean, sbc_SE = np.zeros(table_shape), np.zeros(table_shape)
tarp_mean, tarp_SE = np.zeros(table_shape), np.zeros(table_shape)
predictive_mse_mean, predictive_mse_SE = np.zeros(table_shape), np.zeros(table_shape)

for ii, method in enumerate(methods):
    sbc_results[method], tarp_results[method], predictive_mse_results[method] = {}, {}, {}
    for kk, nsim in enumerate(nsims):
        # All runs of this method trained with this simulation budget (mask built once).
        subset = results[(results["method"] == method) & (results["nsim"] == nsim)]

        # "sbcs" and "predictive_mses" entries are sequences that are concatenated
        # across runs; each run's "tarps" entry is kept as a single pooled item.
        sbc_results[method][nsim] = _pool_runs(subset["sbcs"], flatten=True)
        sbc_mean[ii, kk], sbc_SE[ii, kk] = _mean_and_se(sbc_results[method][nsim])

        tarp_results[method][nsim] = _pool_runs(subset["tarps"], flatten=False)
        tarp_mean[ii, kk], tarp_SE[ii, kk] = _mean_and_se(tarp_results[method][nsim])

        predictive_mse_results[method][nsim] = _pool_runs(subset["predictive_mses"], flatten=True)
        predictive_mse_mean[ii, kk], predictive_mse_SE[ii, kk] = _mean_and_se(
            predictive_mse_results[method][nsim]
        )

In [None]:
# Dump every summary table, then split each into one row per method.
for heading, table in (
    ("SBC mean", sbc_mean),
    ("SBC SE", sbc_SE),
    ("TARP mean", tarp_mean),
    ("TARP SE", tarp_SE),
    ("Predictive MSE mean", predictive_mse_mean),
    ("Predictive MSE SE", predictive_mse_SE),
):
    print(heading)
    print(table)

# Row order follows `methods`: row 0 is the FNOPE variant, row 1 is Simformer.
FNO_ROW, SIMFORMER_ROW = 0, 1

fno_sbc_mean, fno_sbc_SE = sbc_mean[FNO_ROW], sbc_SE[FNO_ROW]
simformer_sbc_mean, simformer_sbc_SE = sbc_mean[SIMFORMER_ROW], sbc_SE[SIMFORMER_ROW]

fno_tarp_mean, fno_tarp_SE = tarp_mean[FNO_ROW], tarp_SE[FNO_ROW]
simformer_tarp_mean, simformer_tarp_SE = tarp_mean[SIMFORMER_ROW], tarp_SE[SIMFORMER_ROW]

fno_predictive_mse_mean = predictive_mse_mean[FNO_ROW]
fno_predictive_mse_SE = predictive_mse_SE[FNO_ROW]
simformer_predictive_mse_mean = predictive_mse_mean[SIMFORMER_ROW]
simformer_predictive_mse_SE = predictive_mse_SE[SIMFORMER_ROW]

In [None]:
# Read the posterior / predictive summaries of run 3 of the 100k-simulation
# experiment, which feed the qualitative panels. These pickles are written by this
# project; never unpickle files from an untrusted source.
def _load_pickle(path):
    """Unpickle and return the object stored at `path`."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


run_dir = out_dir / "sir_experiment/FNOPE/num_sim_100000_run_3"

fno_results_file = run_dir / "fno_predictive_summary.pkl"
fno_results = _load_pickle(fno_results_file)
print(fno_results.keys())

simformer_results_file = run_dir / "simformer_predictive_summary.pkl"
simformer_results = _load_pickle(simformer_results_file)
print(simformer_results.keys())

In [None]:
# Test-set ground truth (theta) and observations (x) are taken from the Simformer
# summary; posterior and posterior-predictive samples from each method's own file.
theta_test, test_x = simformer_results["theta_test"], simformer_results["x_test"]

fno_posterior_samples, fno_predictive_samples = (
    fno_results["posterior_samples"],
    fno_results["posterior_predictive_samples"],
)
simformer_posterior_samples, simformer_predictive_samples = (
    simformer_results["posterior_samples"],
    simformer_results["posterior_predictive_samples"],
)

In [None]:
# Time grids of the test set. np.load on an .npz returns an NpzFile that keeps the
# underlying file open, so read "meta_data" once inside a context manager.
test_data_name = f"posterior100k_samples_{n_cond}_time_points.npz"
with np.load(data_dir / "sir" / test_data_name) as test_data:
    meta_data = test_data["meta_data"]

# Column layout as sliced here: [:, 2:2+n_cond] holds the times at which the
# contact rate (theta) is evaluated, [:, 2+n_cond:2+2*n_cond] the observation
# times of x. The first two columns are not used in this cell.
test_times_theta = torch.from_numpy(meta_data[:, 2:2 + n_cond]).to(device)
test_times_x = torch.from_numpy(meta_data[:, 2 + n_cond:2 + 2 * n_cond]).to(device)

eval_times = test_times_theta[obs_index].detach().cpu().numpy()
# Observation times with a leading t = 0: the predictive samples are plotted on
# this grid, while the observed data are plotted on it without the leading 0.
test_times_x_with_0 = torch.cat(
    (torch.Tensor([0.0]).to(device), test_times_x[obs_index]), dim=0
).detach().cpu().numpy()

In [None]:
def _mean_and_std(samples):
    """Mean and std of `samples` along axis 0 (the sample axis)."""
    return samples.mean(axis=0), samples.std(axis=0)


# Posterior samples for the plotted observation, split into the two scalar
# parameters (first two coordinates) and the contact-rate values (the rest).
fno_cont_samples = fno_posterior_samples[:, obs_index, 2:]
simformer_cont_samples = simformer_posterior_samples[:, obs_index, 2:]
fno_finite_samples = fno_posterior_samples[:, obs_index, :2]
simformer_finite_samples = simformer_posterior_samples[:, obs_index, :2]

fno_predictives = fno_predictive_samples[:, obs_index]
simformer_predictives = simformer_predictive_samples[:, obs_index]

# Same split for the ground-truth parameters of every test observation.
theta_test_cont = theta_test[:, 2:]
theta_test_finite = theta_test[:, :2]

fno_mean, fno_std = _mean_and_std(fno_cont_samples)
fno_finite_mean, fno_finite_std = _mean_and_std(fno_finite_samples)
fno_predictive_mean, fno_predictive_std = _mean_and_std(fno_predictives)

simformer_mean, simformer_std = _mean_and_std(simformer_cont_samples)
simformer_finite_mean, simformer_finite_std = _mean_and_std(simformer_finite_samples)
simformer_predictive_mean, simformer_predictive_std = _mean_and_std(simformer_predictives)



In [None]:
#### Calculate the minimal achievable SBC from uniform ranks ####
# Reference level for the dotted line in the SBC panel: the absolute ATC obtained
# when ranks are drawn uniformly at random (a perfectly calibrated posterior) with
# the same finite number of SBC runs. Seeded so the line is reproducible.

SBC_BOUND_SEED = 0
num_posterior_samples = 1000
n_sbc = 100


def atc_per_dim(coverage_values, bins=30):
    """Return (signed ATC, absolute ATC) for each column of `coverage_values`.

    For every column, the empirical coverage CDF is built from a `bins`-bin
    histogram and compared with the histogram's bin edges (the ideal CDF);
    ATC is the mean difference, absolute ATC the mean absolute difference.
    Both are returned as 1-D tensors with one entry per column.
    """
    atcs, absolute_atcs = [], []
    for dim_idx in range(coverage_values.shape[1]):
        hist, alpha_grid = torch.histogram(coverage_values[:, dim_idx], density=True, bins=bins)
        # Prepend 0 so the ECP curve has one entry per bin edge, matching alpha_grid.
        ecp = torch.cat([torch.Tensor([0]), torch.cumsum(hist, dim=0) / hist.sum()])
        atcs.append((ecp - alpha_grid).mean().item())
        absolute_atcs.append((ecp - alpha_grid).abs().mean().item())
    return torch.tensor(atcs), torch.tensor(absolute_atcs)


# One rank per SBC run and per parameter dimension (n_cond contact-rate values + 2 scalars).
sbc_rng = np.random.default_rng(SBC_BOUND_SEED)
ranks = sbc_rng.integers(0, num_posterior_samples, size=(n_sbc, n_cond + 2))
coverage_values = torch.Tensor(ranks) / num_posterior_samples

atcs, absolute_atcs = atc_per_dim(coverage_values)
print(absolute_atcs)

mean_absolute_atc = absolute_atcs.mean().numpy()
print(mean_absolute_atc)

In [None]:
# Summary figure:
#   a  - posterior over the two scalar parameters (right) and the contact rate (left),
#   b  - posterior predictives per compartment,
#   c  - predictive MSE and SBC error versus simulation budget,
# for FNOPE and Simformer on test observation `obs_index`.
from matplotlib.ticker import ScalarFormatter

fig_version = "v1"
SAVE_FIG = False  # set True to write svg/pdf/png copies into sir_plots/
# Simformer gets a fixed colour in this figure instead of its plotting_utils palette entry.
simformer_color = 'turquoise'
fnope_color = colors[method_names.index("FNOPE")]


def set_endpoint_yticks(ax, decimals=1, max_decimals=4):
    """Give `ax` two y-ticks near the bottom and top of its current y-range.

    Tick positions are rounded inwards (the bottom one is clipped at 0) to the
    coarsest precision between `decimals` and `max_decimals` that still yields two
    distinct ticks, so every label states exactly where its tick sits. If the range
    is too narrow for that, the ticks go on the raw limits with 3 significant digits.
    """
    ymin, ymax = ax.get_ylim()
    lower = ymin if ymin > 0 else 0.0
    for d in range(decimals, max_decimals + 1):
        scale = 10 ** d
        # Round before ceil/floor so float noise (e.g. 0.7 * 10 = 7.000000000000001)
        # cannot move a tick by one step.
        low = np.ceil(np.round(lower * scale, 9)) / scale
        high = np.floor(np.round(ymax * scale, 9)) / scale
        if high > low:
            ax.set_yticks([low, high])
            ax.set_yticklabels([f"{low:.{d}f}", f"{high:.{d}f}"])
            return
    ax.set_yticks([ymin, ymax])
    ax.set_yticklabels([f"{ymin:.3g}", f"{ymax:.3g}"])


def plot_mean_band(ax, x, mean, std, color, band_alpha, **line_kwargs):
    """Plot `mean` against `x` as a line plus a shaded mean +/- std band of the same colour."""
    ax.plot(x, mean, color=color, **line_kwargs)
    ax.fill_between(x, mean - std, mean + std, color=color, alpha=band_alpha)


def plot_metric_curve(ax, x, mean, se, color):
    """Plot a metric against `x` as markers joined by lines, with +/- `se` error bars."""
    ax.errorbar(x, mean, yerr=se, fmt='o', linestyle='-', color=color)


with plt.rc_context(fname="matplotlibrc"):
    kwargs_text = {"fontsize": "10", "font": "Arial", "weight": "800"}

    fig = plt.figure(figsize=(3.47, 4))

    # 14 x 2 grid: rows 0-2 hold the posterior panels, rows 3-4 are near-zero-height
    # spacers, rows 5-13 hold the predictive panels (left) and metric panels (right).
    gs = GridSpec(14, 2, width_ratios=[3, 1.2], height_ratios=[1] * 3 + [0.01] * 2 + [1] * 9,
                  figure=fig, hspace=5., wspace=0.6)

    fig.text(-0.04, 0.9, 'a', ha='center', **kwargs_text)
    fig.text(-0.04, 0.63, 'b', ha='center', va='bottom', **kwargs_text)
    fig.text(0.58, 0.63, 'c', ha='center', va='bottom', **kwargs_text)

    # --- Posterior over the two scalar parameters (KDE contours + ground truth) ---
    ax1 = fig.add_subplot(gs[0:3, 1])
    ax1.set_xlabel("Recovery rate")
    ax1.set_ylabel("Death rate", labelpad=2)
    ax1.scatter(theta_test_finite[obs_index, 0],
                theta_test_finite[obs_index, 1],
                color='black', label='GT', zorder=10)
    sns.kdeplot(x=fno_finite_samples[:, 0],
                y=fno_finite_samples[:, 1],
                ax=ax1,
                color=fnope_color,
                levels=4,
                fill=False)
    sns.kdeplot(x=simformer_finite_samples[:, 0],
                y=simformer_finite_samples[:, 1],
                ax=ax1,
                color=simformer_color,
                alpha=0.75,
                levels=4,
                fill=False)
    # Fixed ticks for this panel (set_xticks/set_yticks widen the view if needed).
    x_min, x_max = 0.2, 0.4
    y_min, y_max = 0.2, 0.4
    ax1.set_xticks([x_min, x_max])
    ax1.set_yticks([y_min, y_max])
    ax1.set_xticklabels([x_min, x_max])
    ax1.set_yticklabels([y_min, y_max])

    # --- Functional posterior: contact rate over time ---
    ax2 = fig.add_subplot(gs[0:3, 0])
    ax2.set_ylabel("Contact rate")
    plot_mean_band(ax2, eval_times, fno_mean, fno_std, fnope_color, 0.2, label='FNO')
    plot_mean_band(ax2, eval_times, simformer_mean, simformer_std, simformer_color, 0.4,
                   label='Simformer')
    ax2.plot(eval_times,
             theta_test_cont[obs_index],
             color="black",
             label="GT",
             linestyle='dashed')
    set_endpoint_yticks(ax2)
    ax2.set_xlabel("Time [days]")

    # --- Metrics vs. simulation budget: predictive MSE ---
    ax3 = fig.add_subplot(gs[5:9, 1])
    plot_metric_curve(ax3, nsims, simformer_predictive_mse_mean, simformer_predictive_mse_SE,
                      simformer_color)
    plot_metric_curve(ax3, nsims, fno_predictive_mse_mean, fno_predictive_mse_SE, fnope_color)
    ax3.set_yticks([0, 8e-3])
    ax3.tick_params(labelbottom=False)

    # Force a x10^-3 offset label instead of per-tick scientific notation.
    formatter = ScalarFormatter(useMathText=True)
    formatter.set_powerlimits((-3, -3))
    ax3.yaxis.set_major_formatter(formatter)
    ax3.ticklabel_format(axis='y', style='scientific')
    ax3.yaxis.offsetText.set_visible(True)
    ax3.set_ylabel('MSE', labelpad=8)

    # The offset text only exists after a draw; then nudge it left.
    fig.canvas.draw()
    offset = ax3.yaxis.get_offset_text()
    offset.set_horizontalalignment('left')
    offset.set_x(-0.15)

    # --- Metrics vs. simulation budget: SBC error with uniform-rank lower bound ---
    ax4 = fig.add_subplot(gs[10:14, 1], sharex=ax3)
    plot_metric_curve(ax4, nsims, simformer_sbc_mean, simformer_sbc_SE, simformer_color)
    plot_metric_curve(ax4, nsims, fno_sbc_mean, fno_sbc_SE, fnope_color)
    # Lower-bound line spans exactly the plotted simulation budgets.
    ax4.hlines(mean_absolute_atc, min(nsims), max(nsims), linestyle=':', color='black',
               label='lower\nbound')
    ax4.set_xscale("log")
    ax4.set_xlabel('# simulations')
    ax4.set_ylabel('SBC EoD', labelpad=2)
    ax4.set_xticks(nsims)
    ax4.set_xlim(500, 200_000)
    ax4.minorticks_off()
    ax4.set_yticks([0, 0.2])
    ax4.set_ylim(0, 0.2)
    ax4.legend(handlelength=1.1,
               loc='upper right',
               bbox_to_anchor=(1.1, 1.1))

    # --- Posterior predictives, one panel per channel, stacked under ax2 ---
    channel_labels = ['Infected', 'Recovered', 'Deceased']
    n_channels = simformer_predictive_mean.shape[0]
    for channel in range(n_channels):
        # Channel c occupies grid rows 5+3c .. 7+3c of the left column.
        ax = fig.add_subplot(gs[2 + 3 * (channel + 1):2 + 3 * (channel + 2), 0], sharex=ax2)
        ax.set_ylabel(channel_labels[channel])

        plot_mean_band(ax, test_times_x_with_0, fno_predictive_mean[channel],
                       fno_predictive_std[channel], fnope_color, 0.2,
                       linestyle='solid', label="FNOPE")
        plot_mean_band(ax, test_times_x_with_0, simformer_predictive_mean[channel],
                       simformer_predictive_std[channel], simformer_color, 0.4,
                       label="Simformer")
        # Data are plotted on the observation times, i.e. without the prepended t = 0.
        ax.plot(test_times_x_with_0[1:],
                test_x[obs_index, channel],
                color="black",
                label="Data",
                linestyle='',
                marker="x",
                markersize=3)

        set_endpoint_yticks(ax)

        if channel < n_channels - 1:
            plt.setp(ax.get_xticklabels(), visible=False)
            ax.tick_params(labelbottom=False)

            if channel == 0:
                # Legend only on the top panel, in a fixed order (band fills carry no label).
                label_order = ['FNOPE', 'Simformer', 'Data']
                handles, labels = ax.get_legend_handles_labels()
                label_to_handle = dict(zip(labels, handles))
                ordered_handles = [label_to_handle[label] for label in label_order]
                ax.legend(ordered_handles,
                          label_order,
                          frameon=False,
                          ncol=1,
                          handlelength=1.5,
                          bbox_to_anchor=(1.05, 1.15))
        else:
            ax.set_xlabel("Time [days]", labelpad=8)

    # Minimal spines: drop top/right, push left/bottom slightly outward.
    for ax in fig.axes:
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_position(('outward', 5))
        ax.spines['bottom'].set_position(('outward', 1))

    # Save before plt.show(), which hands the figure to the inline backend.
    if SAVE_FIG:
        stem = f"sir_plots/sir_summary_n_eval_{n_cond}_simformer_init_{str(use_simformer_init)}_{fig_version}"
        fig.savefig(f"{stem}.svg", bbox_inches='tight')
        fig.savefig(f"{stem}.pdf", bbox_inches='tight')
        fig.savefig(f"{stem}.png", bbox_inches='tight', dpi=300)

    plt.show()