In [None]:
import argparse
import pickle
from copy import deepcopy
import torch
from sbi.inference import SNPE

from sbi.utils.user_input_checks import get_batch_loop_simulator, process_custom_prior
from sbi.simulators.simutils import simulate_in_batches
from sbi.utils.get_nn_models import posterior_nn

from sbi_smfs.simulator import get_simulator_from_config
from sbi_smfs.inference.priors import get_priors_from_config
from sbi_smfs.inference.embedding_net import SimpleCNN
from sbi_smfs.utils.config_utils import get_config_parser
from sbi_smfs.analysis.plot_posterior import plot_spline_ensemble, plot_spline_mean_with_error, plot_spline

In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook can still be executed end-to-end on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG so prior/proposal draws are reproducible across
# Restart & Run All. NOTE(review): the simulator may draw from its own RNG
# (e.g. numpy/numba) — confirm whether it needs to be seeded as well.
torch.manual_seed(0)

config_file = "../../sbi_smfs_experiments/experiment_1/fixed_Dx.config"
prior = get_priors_from_config(config_file, device=device)
simulator = get_simulator_from_config(config_file)
config = get_config_parser(config_file)

In [None]:
# Ground-truth parameter vector (13 entries) used to simulate the synthetic
# observation. Entries from index 2 onward are the ones passed to
# `plot_spline` for comparison; the meaning of indices 0 and 1 is not visible
# here — NOTE(review): document them.
true_parameters = torch.tensor(
    [-1.0, 0.5, 5.0, 1.0, 4.0, 5.0, 4.0, 4.0, 7.0, 3.0, 2.0, 1.0, 5.0],
    dtype=torch.float64,
)

In [None]:
# Simulate one synthetic observation at the ground-truth parameters; it is the
# data the sequential rounds and the final posterior sampling condition on.
# NOTE(review): if the simulator is stochastic, make sure its RNG is seeded
# beforehand so this observation is reproducible across runs.
observation = simulator(true_parameters)

In [None]:
# Persist the synthetic observation next to the experiment config so later
# analyses can reuse exactly the same data.
torch.save(obj=observation, f="../../sbi_smfs_experiments/experiment_1/obs_exp1.pt")

In [None]:
# Load a previously trained posterior object, mapping all tensors onto the CPU.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# checkpoints you produced yourself. PyTorch >= 2.6 defaults to
# weights_only=True, which rejects full pickled objects like this one; pass
# weights_only=False (trusted files only) if loading fails there.
# In the visible notebook this object is only used by the `neural_posterior`
# lambda in the next cell.
pre_trained_posterior = torch.load(f="posterior.pt", map_location="cpu")

In [None]:
def neural_posterior(theta, x):
    """Density-estimator factory that ignores its arguments and returns a fresh
    deep copy of the pre-trained posterior's estimator network.

    NOTE(review): `neural_posterior` is re-bound by the `posterior_nn(...)`
    call in the following cell, so as the notebook stands this factory is
    never passed to SNPE. Confirm whether warm-starting from the pre-trained
    estimator was intended.
    """
    return deepcopy(pre_trained_posterior.posterior_estimator)

In [None]:
# Summary-statistic dimensions read from the config. The lag-time count is
# passed twice to SimpleCNN (1st and 5th positional argument), so read it once.
num_lag_times = len(config.getlistint("SUMMARY_STATS", "lag_times"))
num_histogram_bins = config.getint("SUMMARY_STATS", "num_bins")

# CNN embedding network that compresses the observation before it reaches
# the flow. NOTE(review): the literals 4 and 2 are SimpleCNN's 2nd/3rd
# positional arguments; their meaning is defined in sbi_smfs — name them.
cnn_net = SimpleCNN(num_lag_times, 4, 2, num_histogram_bins, num_lag_times)

# Extra keyword arguments forwarded to the flow builder by posterior_nn.
kwargs_flow = dict(num_blocks=2, dropout_probability=0.0, use_batch_norm=False)

# Neural spline flow density estimator conditioned on the CNN embedding;
# z-scoring of x is disabled.
neural_posterior = posterior_nn(
    model="nsf",
    hidden_features=100,
    num_transforms=5,
    num_bins=10,
    embedding_net=cnn_net,
    z_score_x="none",
    **kwargs_flow,
)

In [None]:
# Create a new SNPE trainer; re-running this cell replaces the previous
# `inference` object and restarts the proposal from the prior.
inference = SNPE(prior=prior, density_estimator=neural_posterior, device=device)

# Wrap a freshly built simulator instead of re-binding `simulator` to a wrapper
# of itself: the previous `simulator = get_batch_loop_simulator(simulator)`
# wrapped the already-wrapped function again every time this cell was re-run,
# breaking batched simulation.
simulator = get_batch_loop_simulator(get_simulator_from_config(config_file))

# Round 0 draws its parameters from the prior.
proposal = prior

In [None]:
# Sequential-inference budget: number of rounds, simulations drawn per round,
# and number of workers handed to `simulate_in_batches`.
num_rounds, num_sim_per_round, num_workers = 3, 200, 24

In [None]:
# Multi-round SNPE. Each round: sample parameters from the current proposal,
# simulate, append the (theta, x) pairs, retrain, and use the resulting
# posterior (with `observation` as its default x) as the next proposal.
# NOTE: this cell mutates `inference` and `proposal`; re-running it without
# re-running the SNPE setup cell continues training for further rounds.
train_settings = {
    "show_train_summary": True,
    "validation_fraction": 0.15,
    "training_batch_size": 50,
    "learning_rate": 0.0005,
    "stop_after_epochs": 20,
}

for idx_round in range(num_rounds):
    # Parameters from the current proposal (the prior in round 0).
    theta = proposal.sample((num_sim_per_round,))

    # Simulate on CPU tensors in batches of 5 using `num_workers` workers.
    x = simulate_in_batches(
        simulator,
        theta.cpu(),
        sim_batch_size=5,
        num_workers=num_workers,
    )

    # Training data is stored on the CPU.
    inference = inference.append_simulations(theta, x, proposal=proposal, data_device="cpu")
    density_estimator = inference.train(**train_settings)

    posterior = inference.build_posterior(density_estimator)
    proposal = posterior.set_default_x(observation)

In [None]:
# Draw posterior samples conditioned on the observation. The observation is
# moved to the configured `device` instead of hard-coding `.cuda()`, so this
# cell follows whatever device the rest of the notebook uses.
samples = posterior.sample((10000,), x=observation.to(device))

In [None]:
# Shift each posterior draw (row) so that its parameters from index 2 onward
# have zero mean. Modifies `samples` in place.
# NOTE(review): presumably these are spline coefficients that are identifiable
# only up to an additive constant — confirm.
samples[:, 2:] -= samples[:, 2:].mean(dim=1, keepdim=True)

In [None]:
# Overlay posterior spline draws with the ground truth (red). The posterior
# draws were shifted so that their parameters from index 2 onward have zero
# mean, so the true parameters must be shifted by *their own* mean as well.
# The previous hard-coded offset of 3.5 did not match that mean
# (41/11 ≈ 3.727), which displaced the reference curve.
true_spline_params = true_parameters[2:]
plot_spline_ensemble(samples, 1000, config_file)
plot_spline(true_spline_params - true_spline_params.mean(), config_file, color='red')