### Sample the flow to get posterior samples and simulate their summary statistics

In [1]:
import torch
import numpy as np

from pyloric import simulate, create_prior, stats
from joblib import delayed, Parallel
from copy import deepcopy
from multiprocessing import Pool
import time

In [2]:
# Seed torch's global RNG so torch-based random draws in this notebook
# (e.g. the simulator seeds produced with torch.randint) are repeatable.
SEED = 123123
_ = torch.manual_seed(SEED)

In [3]:
# Archive that stores the fitted flow; it is read back below via data['posterior'].
# NOTE(review): this is an absolute, machine-specific path -- consider making it
# configurable so the notebook runs on other machines.
# NOTE(review): allow_pickle=True unpickles arbitrary objects while loading; only
# use it on trusted files.
flow_archive_path = '/home/michael/Documents/STG_energy/results/flow/200411_flow.npz'
data = np.load(flow_archive_path, allow_pickle=True)

In [4]:
# Pull the stored posterior out of the loaded archive.
# Presumably this entry is a 0-d object array, in which case .tolist() returns the
# wrapped Python object itself -- TODO confirm.
posterior_entry = data['posterior']
flow = posterior_entry.tolist()

# Prior built by pyloric's create_prior().
# NOTE(review): `prior` is not referenced by any of the cells visible here.
prior = create_prior()

In [None]:
# Draw parameter sets from the flow, simulate each one in parallel, and write one
# .npz file per iteration containing the parameters, the simulator seeds and the
# resulting summary statistics.

# Number of flow samples requested per iteration (fewer may remain after filtering).
num_samples_per_iter = 2520


def simulator(params_set):
    """Simulate a single parameter set and return its summary statistics.

    Args:
        params_set: 1-D numpy array; all entries except the last are the
            parameters handed to `simulate`, the last entry is the seed.

    Returns:
        Whatever `stats` returns for the simulated trace.
    """
    # astype() already returns a new array, so no extra copy is needed to keep
    # the simulator from touching the caller's data.
    out_target = simulate(
        params_set[:-1].astype(np.float64),
        seed=int(params_set[-1]),
    )
    return stats(out_target)


# Defined once, outside the loop: neither the helper nor the sample count
# changes between iterations.
for k in range(2000):
    start_time = time.time()

    flow_samples = flow.sample(num_samples_per_iter)
    # Drop samples to which the flow assigns zero density.
    # NOTE(review): exp() can underflow to 0 for very negative but finite
    # log-probabilities; comparing log_prob against -inf may be what is intended.
    flow_samples = flow_samples[torch.exp(flow.log_prob(flow_samples)) > 0.0]

    # One seed per *retained* sample. Sizing this by num_samples_per_iter makes
    # torch.cat fail as soon as the filter above removes at least one sample.
    seeds = torch.randint(0, 10000, (flow_samples.shape[0], 1))
    params_with_seeds = torch.cat((flow_samples, seeds), dim=1).detach().numpy()

    simulation_outputs = Parallel(n_jobs=12)(
        delayed(simulator)(params_set)
        for params_set in params_with_seeds
    )
    simulation_outputs = np.asarray(simulation_outputs)

    # Store the parameters under `params` (previously the stats were written
    # twice and the parameters were lost); also keep the seeds so every
    # simulation can be reproduced from the saved file.
    np.savez(
        f"../../results/11deg_post_pred/11deg_5million_predictives_for_temp/simulated/11deg_5million_predictives_for_temp_{k}.npz",
        params=params_with_seeds[:, :-1],
        seeds=params_with_seeds[:, -1].astype(np.int64),
        stats=simulation_outputs,
    )
    print("Overall time for iteration", k, ":   ", time.time()-start_time)

Overall time for iteration 0 :    253.24006032943726
Overall time for iteration 1 :    250.19604897499084
Overall time for iteration 2 :    250.91529035568237
Overall time for iteration 3 :    252.21217226982117
Overall time for iteration 4 :    256.71840596199036
Overall time for iteration 5 :    257.5032002925873
Overall time for iteration 6 :    257.1373643875122
Overall time for iteration 7 :    255.10482954978943
Overall time for iteration 8 :    258.20604276657104
Overall time for iteration 9 :    256.9612317085266
Overall time for iteration 10 :    253.80482625961304
Overall time for iteration 11 :    258.8806004524231
Overall time for iteration 12 :    256.51437067985535
Overall time for iteration 13 :    254.38720631599426
Overall time for iteration 14 :    261.8491177558899
Overall time for iteration 15 :    258.2046139240265
