In [1]:
import numpy as np
import importlib
import matplotlib as mpl
import matplotlib.pylab as plt
import time
from copy import deepcopy
import torch

from stg_energy.common import col, svg, samples_nd, get_labels_8pt
from stg_energy.fig5_cc.viz import vis_sample_plain
from pyloric.sbi_prior import create_prior, create_prior_general
from pyloric.sbi_wrapper import simulate, simulate_general, load_setup, get_time, stats
from stg_energy.fig7_temp.process_samples import merge_samples
import stg_energy.fig7_temp.viz

from stg_energy.fig2_histograms.energy import select_ss_close_to_obs
from parameters import ParameterSet

from joblib import Parallel, delayed

### Load the initial 55,000 samples that were close to the observation

In [2]:
data = np.load("../../results/11deg_post_pred/11_deg_post_pred_close_to_obs.npz")
good_params = data["sample_params"]
good_stats = data["sample_stats"]
good_seeds = data["sample_seeds"]

### Load the other 5 million simulations at 11 degree. We then search these for being close to the observation

In [3]:
datafile = "../../results/prior_samples_after_classifier/samples_full_3.npz"
data = np.load(datafile)
ss_prior = data["stats"]

stats_mean = np.mean(ss_prior, axis=0)
stats_std = np.std(ss_prior, axis=0)

In [4]:
npz = np.load("../../results/experimental_data/summstats_prep845_082_0044.npz")
observation = npz["summ_stats"]

npz = np.load("../../results/experimental_data/trace_data_845_082_0044.npz")
t = npz["t"]

In [5]:
num_std = np.asarray(
    [0.02, 0.02, 0.02, 0.02, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2]
)

In [6]:
total_num_close_to_obs = 0

for k in range(100):
    data = np.load(f"../../results/11deg_post_pred/11deg_5million_predictives_for_temp/simulated/11deg_5million_predictives_for_temp_{k}.npz")
    params_5million = data['params']
    stats_5million = data['stats']
    #seeds_5million = data['seeds']
    seeds_5million = torch.zeros(params_5million.shape[0]).numpy()

    good_params_new, good_dat_new, good_seeds_new = select_ss_close_to_obs(
        params_5million,
        stats_5million,
        seeds_5million,
        observation,
        num_std=num_std,
        stats_std=stats_std[:15],
        new_burst_position_in_ss=True
    )
    if k % 200 == 0: 
        print("Num_of_good: ", good_params_new.shape[0])
        
    total_num_close_to_obs += good_params_new.shape[0]

    np.savez(f"../../results/11deg_post_pred/11deg_5million_predictives_for_temp/close_to_obs/11deg_5million_predictives_for_temp_{k}.npz", 
             params=good_params_new, stats=good_dat_new, seeds=good_seeds_new)
    
print("total number of accepted samples:   ", total_num_close_to_obs)

Num_of_good:  112
total number of accepted samples:    10858


In [7]:
2520*177

446040

### Simulate all data at 27 degree

In [8]:
setups_dict = ParameterSet('/home/michael/Documents/pyloric/pyloric/setups.prm')
hyperparams_11 = setups_dict['collect_samples_15deg_energy_ssRanges']
general_prior = create_prior_general(hyperparams_11, log=True)

In [9]:
hyperparams_11.model_params.temp = 299

In [10]:
for k in range(100):
    start_time = time.time()
    
    data = np.load(f"../../results/11deg_post_pred/11deg_5million_predictives_for_temp/close_to_obs/11deg_5million_predictives_for_temp_{k}.npz")
    params_close = data['params']
    stats_close = data['stats']
    seeds_close = data['seeds']
    
    q10s = general_prior.sample((params_close.shape[0],))[:, -10:].detach().numpy()
    params_with_q10s = np.concatenate((params_close, q10s), axis=1)

    seeds_27 = np.random.randint(0, 10000, (params_with_q10s.shape[0], 1))
    params_with_seeds = np.concatenate((params_with_q10s, seeds_27), axis=1)

    def simulator(params_set):
        out_target = simulate_general(
            deepcopy(params_set[:-1].astype(np.float64)),
            hyperparams=hyperparams_11,
            seed=int(params_set[-1]),
        )
        return stats(out_target)

    simulation_outputs = Parallel(n_jobs=12)(
        delayed(simulator)(batch)
        for batch in params_with_seeds
    )

    simulation_outputs = np.asarray(simulation_outputs)

    np.savez(f"../../results/temperature/sbi/201005_5Million_post_pred_close_to_obs_simulated_at_27/set_{k}.npz", 
             params=params_with_q10s, stats_27=simulation_outputs, stats_11=stats_close, seeds_27=seeds_27, seeds_11=seeds_close)
    print("Overall time for iteration", k, ":   ", time.time()-start_time)

Overall time for iteration 0 :    14.36601734161377
Overall time for iteration 1 :    13.050976037979126
Overall time for iteration 2 :    10.827553272247314
Overall time for iteration 3 :    12.675790786743164
Overall time for iteration 4 :    11.464680194854736
Overall time for iteration 5 :    12.338057279586792
Overall time for iteration 6 :    12.68315315246582
Overall time for iteration 7 :    13.68934416770935
Overall time for iteration 8 :    12.336710691452026
Overall time for iteration 9 :    11.930693626403809
Overall time for iteration 10 :    13.57594609260559
Overall time for iteration 11 :    13.525963068008423
Overall time for iteration 12 :    12.103846788406372
Overall time for iteration 13 :    10.831341981887817
Overall time for iteration 14 :    11.307703256607056
Overall time for iteration 15 :    10.80772590637207
Overall time for iteration 16 :    12.501270532608032
Overall time for iteration 17 :    11.194868087768555
Overall time for iteration 18 :    15.07003

### Select only those simulations that are robust at 27 degree
Why not also check for the Q10 of the cycle frequency? Because the Q10 of tau is anyways fixed. So they will anyways increase in speed.

In [12]:
all_params = []
all_stats_27 = []
all_stats_11 = []
all_seeds_27 = []
all_seeds_11 = []

for k in range(100):
    start_time = time.time()

    data = np.load(f"../../results/temperature/sbi/201005_5Million_post_pred_close_to_obs_simulated_at_27/set_{k}.npz")
    params_close = data['params']
    stats_27 = data['stats_27']
    stats_11 = data['stats_11']
    seeds_27 = data['seeds_27']
    seeds_11 = data['seeds_11']

    condition = np.invert(np.any(np.isnan(stats_27), axis=1))

    all_params.append(params_close[condition])
    all_stats_27.append(stats_27[condition])
    all_stats_11.append(stats_11[condition])
    all_seeds_27.append(seeds_27[condition])
    all_seeds_11.append(seeds_11[condition])

all_params = np.concatenate(all_params)
all_stats_27 = np.concatenate(all_stats_27)
all_stats_11 = np.concatenate(all_stats_11)
all_seeds_27 = np.concatenate(all_seeds_27)
all_seeds_11 = np.concatenate(all_seeds_11)

print("Total number of robust samples:  ", all_params.shape[0])
np.savez(f"../../results/temperature/sbi/201005_5Million_post_pred_close_to_obs_simulated_at_27_and_robust.npz", 
         params=all_params, stats_27=all_stats_27, stats_11=all_stats_11, seeds_27=all_seeds_27, seeds_11=all_seeds_11)

Total number of robust samples:   2569
