In [13]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sbi_smfs.utils.config_utils import get_config_parser
from sbi_smfs.inference.sequential_posterior import train_sequential_posterior
from sbi_smfs.simulator.simulator import get_simulator_from_config
from sbi_smfs.analysis.plot_posterior import plot_spline_ensemble, plot_spline_mean_with_error, plot_spline

In [44]:
# Paths to the two configuration files used in this notebook: `config_file`
# drives the first posterior, `config_file2` the second one.
# NOTE(review): the first path goes through a `config_files/` directory and the
# second does not -- confirm both files really live where these paths point.
config_file, config_file2 = (
    "../../sbi_smfs_experiments/experiment_1/config_files/fixed_Dx.config",
    "../../sbi_smfs_experiments/experiment_1/indip_Dx.config",
)

In [45]:
# Parse the first configuration file into a config object.
# NOTE(review): `config` is not referenced again in the cells visible here;
# every later call receives the path string `config_file` instead.
config = get_config_parser(config_file)

In [49]:
# Build the simulator callable from the first configuration file; it is called
# below with a parameter tensor to generate the synthetic observation.
simulator = get_simulator_from_config(config_file)

In [50]:
# Ground-truth parameter vector (13 entries, float64) fed to the simulator.
true_parameters = torch.tensor(
    [
        # Leading two entries; the histogram cell for the first posterior
        # labels the matching sample columns 'Dq' and 'k'.
        -1.0, 0.5,
        # Remaining eleven entries; passed to plot_spline as true_parameters[2:].
        5.0, 1.0, 4.0, 5.0, 4.0, 4.0, 7.0, 3.0, 2.0, 1.0, 5.0,
    ],
    dtype=torch.float64,
)

In [51]:
# Run the simulator once at the ground-truth parameters; the result is the
# synthetic observation both posteriors below are trained on and evaluated at.
# NOTE(review): no random seed is set in the visible cells, so if the simulator
# is stochastic this observation will differ between runs -- confirm.
observation = simulator(true_parameters)

In [None]:
# Train the posterior for the first configuration on the synthetic observation,
# using the GPU.
# NOTE(review): the meaning of the positional arguments 1, 2000 and 24 is not
# visible here (presumably rounds / simulations per round / workers) -- confirm
# against train_sequential_posterior's signature.
posterior = train_sequential_posterior(config_file, 1, 2000, 24, observation, device='cuda')

In [None]:
# Persist the synthetic observation to disk so it can be reloaded later.
torch.save(observation, "../../sbi_smfs_experiments/experiment_1/obs_exp1.pt")

In [None]:
# Draw 10000 samples from the first posterior at the observation; the
# observation is moved to the CUDA device before being passed as `x`.
samples = posterior.sample((10000,), x=observation.cuda()) 

In [None]:
# Re-centre each sample's columns 2 onward: subtract that row's own mean over
# those columns so they average to zero per sample. Columns 0 and 1 (labelled
# 'Dq' and 'k' in the histogram cell) are left untouched.
samples[:, 2:] -= samples[:, 2:].mean(dim=1, keepdim=True)

In [None]:
# Draw an ensemble of posterior splines (second argument 1000) for the first
# configuration, then overlay the ground-truth curve in red.
plot_spline_ensemble(samples, 1000, config_file)
# NOTE(review): the posterior samples were centred on their per-row mean, but the
# reference curve is shifted by a hand-picked 3.5; the mean of
# true_parameters[2:] is 41/11 (about 3.73). Confirm the offset is intentional.
plot_spline(true_parameters[2:] - 3.5, config_file, color='red')

In [None]:
# Plot the posterior spline mean with an error band (alpha=0.025) for the first
# configuration, then overlay the ground-truth curve in red.
plot_spline_mean_with_error(samples, config_file, alpha=0.025, ylims=(-2, 10))
# NOTE(review): the reference curve is shifted by 0.5 here but by 3.5 in the
# ensemble plot, and the two calls pass different ylims -- confirm both are
# intended.
plot_spline(true_parameters[2:] - 0.5, config_file, color='red', ylims=(-10, 10))

In [None]:
# Overlay 100-bin histograms of the first two posterior columns on one axes,
# labelled 'Dq' (column 0) and 'k' (column 1).
for column, name in enumerate(('Dq', 'k')):
    _ = plt.hist(samples[:, column].cpu().numpy(), bins=100, label=name)
plt.legend()

In [None]:
# Train a second posterior, this time from config_file2, on the same synthetic
# observation, using the GPU.
# NOTE(review): the meaning of the positional arguments 15, 1500 and 24 is not
# visible here (presumably rounds / simulations per round / workers) -- confirm
# against train_sequential_posterior's signature.
posterior2 = train_sequential_posterior(config_file2,  15, 1500, 24, observation, device='cuda')


In [None]:
# Draw 100000 samples from the second posterior at the observation (moved to the
# CUDA device) -- ten times as many as were drawn from the first posterior.
samples2 = posterior2.sample((100000,), x=observation.cuda())

In [None]:
# Re-centre the spline-node columns of every sample so they average to zero
# per row.
#
# For this posterior the histogram cells label columns 0, 1 and 2 as 'Dx', 'Dq'
# and 'k', so the spline nodes start at column 3 (one later than for the first
# posterior). Slicing from column 2 would mix `k` into the per-row mean and
# overwrite the `k` column that is plotted further down.
# NOTE(review): the column layout is taken from this notebook's own plot labels;
# confirm against the parameter ordering the simulator uses for config_file2.
samples2[:, 3:] = samples2[:, 3:] - torch.mean(samples2[:, 3:], dim=1, keepdim=True)

In [None]:
# Draw an ensemble of posterior splines (second argument 1000) for the second
# configuration, then overlay the ground-truth curve in red.
plot_spline_ensemble(samples2, 1000, config_file2)
# The reference curve uses the same configuration as the ensemble it is drawn
# over (config_file2), matching the mean-with-error cell for this posterior.
# NOTE(review): the 3.5 shift is a hand-picked offset; the mean of
# true_parameters[2:] is 41/11 (about 3.73). Confirm it is intentional.
plot_spline(true_parameters[2:] - 3.5, config_file2, color='red')

In [None]:
# Plot the posterior spline mean with an error band (alpha=0.025) for the second
# configuration, then overlay the ground-truth curve in red.
plot_spline_mean_with_error(samples2, config_file2, alpha=0.025, ylims=(-2, 10))
# NOTE(review): the reference curve is shifted by 0.5 here but by 3.5 in the
# ensemble plot, and the two calls pass different ylims -- confirm both are
# intended.
plot_spline(true_parameters[2:] - 0.5, config_file2, color='red', ylims=(-10, 10))

In [None]:
# Overlay histograms of the first two columns of the second posterior, labelled
# 'Dx' (column 0) and 'Dq' (column 1), on shared bin edges from
# np.linspace(-2, 1).
for column, name in enumerate(('Dx', 'Dq')):
    _ = plt.hist(samples2[:, column].cpu().numpy(), bins=np.linspace(-2, 1), label=name)
plt.legend()

In [None]:
# Histogram of column 2 of the second posterior, labelled 'k', on bin edges from
# np.linspace(0, 1).
# NOTE(review): this assumes column 2 still holds the unmodified k samples;
# verify that the spline re-centring cell does not touch this column.
k_values = samples2[:, 2].cpu().numpy()
_ = plt.hist(k_values, bins=np.linspace(0, 1), label='k')
plt.legend()