### Sample parameters from the trained flow (posterior) and simulate them

In [1]:
import torch
import numpy as np

from pyloric import simulate, create_prior, stats
from joblib import delayed, Parallel
from copy import deepcopy
from multiprocessing import Pool
import time

In [2]:
# Seed torch's global RNG (used by flow.sample and torch.randint below).
# manual_seed returns a Generator; binding it to `_` keeps the cell output quiet.
SEED = 12312313
_ = torch.manual_seed(SEED)

In [3]:
# Archive holding the trained flow as a pickled object under the 'posterior' key.
# allow_pickle=True unpickles it, which can execute arbitrary code: load trusted files only.
# NOTE(review): absolute, machine-specific path; the other cells use paths relative to this notebook.
flow_file = '/home/michael/Documents/STG_energy/results/flow/200411_flow.npz'
data = np.load(flow_file, allow_pickle=True)

In [4]:
# The flow is stored inside a numpy object array; .tolist() unwraps the Python object.
posterior_entry = data['posterior']
flow = posterior_entry.tolist()

In [5]:
# Sanity check: draw and show a single parameter set from the flow.
one_draw = flow.sample(1)
print(one_draw)

tensor([[ 3.4304e-02,  6.9182e-04,  1.3943e-03,  1.4528e-02,  2.4569e-03,
          2.3212e-02,  4.9989e-06,  1.2571e-06,  7.3174e-02,  5.5572e-04,
          4.4999e-03,  2.8524e-02,  6.1185e-03,  4.8477e-02,  2.4848e-05,
          1.7146e-05,  1.6718e-01,  4.1757e-03,  1.2410e-03,  2.0514e-02,
          1.5498e-03,  6.5672e-02,  1.2471e-05,  9.5230e-06, -9.0803e+00,
         -1.7436e+01, -1.0947e+01, -1.0716e+01, -1.6232e+01, -1.7659e+01,
         -1.2110e+01]], grad_fn=<CatBackward>)


In [6]:
# Build the prior from pyloric.
# NOTE(review): `prior` is not used by any of the cells visible here.
prior = create_prior()

In [7]:
# Previously saved posterior-predictive samples (judging by the file name, the
# ones close to the observation).
# NOTE(review): good_stats / good_params / good_seeds are not used in the cells shown here.
data = np.load("../../results/11deg_post_pred/11_deg_post_pred_close_to_obs.npz")
good_stats, good_params, good_seeds = (
    data[key] for key in ("sample_stats", "sample_params", "sample_seeds")
)

In [9]:
from stg_energy.fig2_histograms.energy import select_ss_close_to_obs

# Summary statistics of prior samples (after the classifier). Their column-wise
# standard deviation is passed as `stats_std` to select_ss_close_to_obs in the
# sampling loop; stats_mean is computed alongside it.
datafile = "../../results/prior_samples_after_classifier/samples_full_3.npz"
data = np.load(datafile)
ss_prior = data["stats"]
stats_mean, stats_std = ss_prior.mean(axis=0), ss_prior.std(axis=0)

In [10]:
summstats_file = "../../results/experimental_data/summstats_prep845_082_0044.npz"
trace_file = "../../results/experimental_data/trace_data_845_082_0044.npz"

# Summary statistics of the experimental recording (prep 845_082_0044).
observation = np.load(summstats_file)["summ_stats"]

# Recorded trace file; `t` is presumably its time vector — confirm.
npz = np.load(trace_file)
t = npz["t"]

In [11]:
# Per-summary-statistic tolerance handed to select_ss_close_to_obs together
# with stats_std[:15] (one entry per statistic). Presumably measured in units
# of that std — confirm. The first four statistics get a 10x tighter tolerance.
num_std = np.asarray([0.02] * 4 + [0.2] * 11)

In [12]:
# Draw posterior-predictive samples in chunks. Iteration k:
#   1. samples `num_samples_per_iter` parameter sets from the flow,
#   2. simulates them in parallel,
#   3. saves params / summary stats / simulator seeds to one .npz file.
#
# Reproducibility fix: the torch RNG is re-seeded with BASE_SEED + k at the
# start of every iteration. Before, it was seeded only once at the top of the
# notebook. Every kernel restart that resumed the loop at a new start index
# (e.g. 215, then 243) therefore re-drew exactly the flow samples and seeds of
# an already-saved chunk. Now chunk k depends only on k, whatever index the
# loop is resumed from.
# NOTE(review): chunks written by earlier runs of the old loop may duplicate
# each other; check before pooling them with new chunks.
BASE_SEED = 12312313
N_JOBS = 12
SAVE_DIR = "../../results/11deg_post_pred/11deg_5million_predictives_for_temp/simulated"
num_samples_per_iter = 2520


def simulator(params_set):
    """Simulate one parameter set and return its summary statistics.

    `params_set` is one row of `params_with_seeds`: the sampled parameters
    followed by the simulator seed as the last entry.
    """
    # astype already returns a fresh float64 copy, so no deepcopy is needed.
    circuit_params = params_set[:-1].astype(np.float64)
    out_target = simulate(circuit_params, seed=int(params_set[-1]))
    return stats(out_target)


all_samples = []  # NOTE(review): never filled here; kept in case later cells read it.
for k in range(243, 2000):
    start_time = time.time()
    torch.manual_seed(BASE_SEED + k)

    # Sampling needs no gradients; avoid building an autograd graph per chunk.
    with torch.no_grad():
        flow_samples = flow.sample(num_samples_per_iter)
    seeds = torch.randint(0, 10000, (num_samples_per_iter, 1))

    # One row per simulation: the parameters, then the seed in the last column.
    params_with_seeds = torch.cat((flow_samples, seeds), dim=1).detach().numpy()

    simulation_outputs = np.asarray(
        Parallel(n_jobs=N_JOBS)(
            delayed(simulator)(param_row) for param_row in params_with_seeds
        )
    )

    # Progress print only: how many simulations land close to the observation.
    good_params_new, good_dat_new, good_seeds_new = select_ss_close_to_obs(
        flow_samples,
        simulation_outputs,
        seeds.squeeze(),
        observation,
        num_std=num_std,
        stats_std=stats_std[:15],
        new_burst_position_in_ss=True,
    )
    print("num of good:  ", good_params_new.shape)

    np.savez(
        f"{SAVE_DIR}/11deg_5million_predictives_for_temp_{k}.npz",
        params=flow_samples.detach().numpy(),
        stats=simulation_outputs,
        seeds=seeds.numpy(),
    )
    print("Overall time for iteration", k, ":   ", time.time() - start_time)

num of good:   torch.Size([112, 31])
Overall time for iteration 215 :    253.0445351600647
num of good:   torch.Size([102, 31])
Overall time for iteration 216 :    250.82126116752625
num of good:   torch.Size([123, 31])
Overall time for iteration 217 :    251.69200897216797
num of good:   torch.Size([128, 31])
Overall time for iteration 218 :    253.2879776954651
num of good:   torch.Size([118, 31])
Overall time for iteration 219 :    255.88742065429688
num of good:   torch.Size([114, 31])
Overall time for iteration 220 :    257.7573673725128
num of good:   torch.Size([91, 31])
Overall time for iteration 221 :    256.9001293182373
num of good:   torch.Size([109, 31])
Overall time for iteration 222 :    259.2209930419922
num of good:   torch.Size([112, 31])
Overall time for iteration 223 :    259.8254613876343
num of good:   torch.Size([118, 31])
Overall time for iteration 224 :    254.41491842269897
num of good:   torch.Size([119, 31])
Overall time for iteration 225 :    254.6636228561

KeyboardInterrupt: 