### Sample from the trained flow to obtain posterior samples

In [7]:
import torch
import numpy as np

from pyloric import simulate, create_prior, summary_stats
from joblib import delayed, Parallel
from copy import deepcopy
from multiprocessing import Pool
import time
import dill as pickle

In [8]:
# Seed torch's global RNG so later torch-based random draws in this notebook
# (e.g. the `torch.randint` simulator seeds) are reproducible across runs.
# Assigning to `_` keeps the notebook from displaying the returned generator.
GLOBAL_TORCH_SEED = 12312313
_ = torch.manual_seed(GLOBAL_TORCH_SEED)

In [11]:
# Load the trained posterior ("flow") for the 11 degree condition. `pickle` here is
# dill, imported under that alias at the top of the notebook.
# Unpickling can execute arbitrary code, so only load files from a trusted source.
# NOTE(review): this path climbs three directories ("../../../results") while later
# cells use "../../results" -- confirm both resolve from the notebook's working dir.
file = "../../../results/trained_neural_nets/inference/posterior_11deg.pickle"
with open(file, mode="rb") as posterior_file:
    flow = pickle.load(posterior_file)

In [13]:
# Observation for the 11 degree condition (per the file name); it is passed as `x`
# when drawing the smoke-test posterior sample in the next cell.
x_o = np.load(file="../../../results/experimental_data/xo_11deg.npy")

In [16]:
# Smoke test: draw one posterior sample conditioned on x_o and display its shape
# (the recorded output shows one parameter vector of length 31).
flow.sample(torch.Size([1]), x=x_o).shape

HBox(children=(FloatProgress(value=0.0, description='Drawing 1 posterior samples', max=1.0, style=ProgressStyl…




torch.Size([1, 31])

In [17]:
# Build the prior via `pyloric.create_prior()`.
# NOTE(review): `prior` is not referenced by any visible cell -- confirm it is needed.
prior = create_prior()

In [29]:
# Load posterior-predictive samples that were previously kept as close to the
# observation (per the file name): their summary stats, parameters and seeds.
# np.load on an .npz returns an NpzFile that keeps the archive open; the `with`
# block closes it once the arrays (read fully into memory on indexing) are extracted.
# NOTE(review): the good_* arrays are not used by any visible cell.
with np.load("../../results/11deg_post_pred/11_deg_post_pred_close_to_obs.npz") as data:
    good_stats = data["sample_stats"]
    good_params = data["sample_params"]
    good_seeds = data["sample_seeds"]

In [30]:
from stg_energy.fig2_histograms.energy import select_ss_close_to_obs

# Summary statistics of prior samples (file kept after a classifier, per its path).
# Their mean and std along axis 0 are computed here; `stats_std[:15]` is later
# passed to select_ss_close_to_obs. The `with` block closes the .npz archive,
# which np.load would otherwise leave open.
datafile = "../../results/prior_samples_after_classifier/samples_full_3.npz"
with np.load(datafile) as data:
    ss_prior = data["stats"]

stats_mean = np.mean(ss_prior, axis=0)
stats_std = np.std(ss_prior, axis=0)

In [31]:
# Experimental data for preparation 845_082_0044: the observed summary statistics
# and `t` from the recorded trace file (presumably its time axis -- confirm).
# Each archive is opened in a `with` block so its file handle is closed after reading.
with np.load("../../results/experimental_data/summstats_prep845_082_0044.npz") as npz:
    observation = npz["summ_stats"]

with np.load("../../results/experimental_data/trace_data_845_082_0044.npz") as npz:
    t = npz["t"]

In [32]:
# Tolerances passed as `num_std` to select_ss_close_to_obs together with
# `stats_std[:15]`, so there is one entry per each of those 15 statistics.
# The first four entries are ten times smaller than the remaining eleven.
# NOTE(review): which summary statistics these positions map to is not visible here.
num_std = np.asarray([0.02] * 4 + [0.2] * 11)

In [36]:
# Draw parameters from the flow in fixed-size batches, simulate each one with its own
# random seed, and save every batch to disk. The batch index starts at 1958,
# presumably to resume a run whose earlier batches were already written -- confirm.
num_samples_per_iter = 2520
output_dir = "../../results/11deg_post_pred/11deg_5million_predictives_for_temp/simulated"


def simulator(params_set):
    """Simulate one row of `params_with_seeds` and return its summary statistics.

    The row holds the model parameters followed by the simulator seed in its last entry.
    """
    out_target = simulate(
        deepcopy(params_set[:-1].astype(np.float64)),
        seed=int(params_set[-1]),
    )
    # NOTE(review): `stats` is not defined in any visible cell (only `summary_stats` is
    # imported from pyloric) -- confirm it is created elsewhere in the notebook.
    return stats(out_target)


for k in range(1958, 2000):
    start_time = time.time()

    # NOTE(review): unlike the smoke test `flow.sample((1,), x=x_o)`, no `x` is passed
    # here, so this relies on the posterior's stored default observation -- confirm it is x_o.
    flow_samples = flow.sample(num_samples_per_iter)
    seeds = torch.randint(0, 10000, (num_samples_per_iter, 1))

    # One row per simulation: the sampled parameters with the seed appended as the last column.
    params_with_seeds = torch.cat((flow_samples, seeds), dim=1).detach().numpy()

    simulation_outputs = np.asarray(
        Parallel(n_jobs=12)(delayed(simulator)(row) for row in params_with_seeds)
    )

    # Only used for the printout: count how many simulations fall close to the observation.
    good_params_new, good_dat_new, good_seeds_new = select_ss_close_to_obs(
        flow_samples,
        simulation_outputs,
        seeds.squeeze(),
        observation,
        num_std=num_std,
        stats_std=stats_std[:15],
        new_burst_position_in_ss=True,
    )
    print("num of good:  ", good_params_new.shape)

    np.savez(
        f"{output_dir}/11deg_5million_predictives_for_temp_{k}.npz",
        params=flow_samples.detach().numpy(),
        stats=simulation_outputs,
        seeds=seeds.numpy(),
    )
    print("Overall time for iteration", k, ":   ", time.time() - start_time)

num of good:   torch.Size([104, 31])
Overall time for iteration 1958 :    258.63058257102966
num of good:   torch.Size([128, 31])
Overall time for iteration 1959 :    253.99272656440735
num of good:   torch.Size([133, 31])
Overall time for iteration 1960 :    252.06354308128357
num of good:   torch.Size([118, 31])
Overall time for iteration 1961 :    250.91775631904602
num of good:   torch.Size([101, 31])
Overall time for iteration 1962 :    251.53859210014343
num of good:   torch.Size([98, 31])
Overall time for iteration 1963 :    252.8818256855011
num of good:   torch.Size([108, 31])
Overall time for iteration 1964 :    254.2624535560608
num of good:   torch.Size([116, 31])
Overall time for iteration 1965 :    253.21874523162842
num of good:   torch.Size([119, 31])
Overall time for iteration 1966 :    253.44809365272522
num of good:   torch.Size([127, 31])
Overall time for iteration 1967 :    252.56233310699463
num of good:   torch.Size([103, 31])
Overall time for iteration 1968 :   