In [1]:
import numpy as np
import importlib
import matplotlib as mpl
import matplotlib.pylab as plt
import time
from copy import deepcopy
import torch

from stg_energy.common import col, svg, samples_nd, get_labels_8pt
from stg_energy.fig5_cc.viz import vis_sample_plain
from pyloric.sbi_prior import create_prior, create_prior_general
from pyloric.sbi_wrapper import simulate, simulate_general, load_setup, get_time, stats
from stg_energy.fig7_temp.process_samples import merge_samples
import stg_energy.fig7_temp.viz

from stg_energy.fig2_histograms.energy import select_ss_close_to_obs
from parameters import ParameterSet

from joblib import Parallel, delayed

### Load the initial 55,000 samples that were close to the observation

In [2]:
# Samples that were already selected as close to the observation in an earlier run.
data = np.load("../../results/11deg_post_pred/11_deg_post_pred_close_to_obs.npz")
good_params, good_stats, good_seeds = (
    data[key] for key in ("sample_params", "sample_stats", "sample_seeds")
)

### Load the other 5 million simulations at 11 degrees. We then search these for samples that are close to the observation

In [5]:
# Summary statistics of the prior samples; their per-column mean and standard
# deviation are computed over all samples (axis 0).
datafile = "../../results/prior_samples_after_classifier/samples_full_3.npz"
data = np.load(datafile)
ss_prior = data["stats"]

stats_mean = ss_prior.mean(axis=0)
stats_std = ss_prior.std(axis=0)

In [6]:
# Summary statistics of the experimental recording used as the observation.
observation = np.load(
    "../../results/experimental_data/summstats_prep845_082_0044.npz"
)["summ_stats"]

# Time vector stored with the experimental trace data.
npz = np.load("../../results/experimental_data/trace_data_845_082_0044.npz")
t = npz["t"]

In [7]:
# Per-statistic tolerance passed to `select_ss_close_to_obs` as `num_std`:
# 0.02 for the first four summary statistics and 0.2 for the remaining eleven.
num_std = np.asarray([0.02] * 4 + [0.2] * 11)

In [None]:
# Go through the 2000 chunks of 11-degree simulations, keep the samples that
# `select_ss_close_to_obs` accepts, and write the survivors of every chunk to its
# own file in the `close_to_obs/` folder.

# Start from the close-to-observation samples loaded earlier in the notebook.
# These lists are neither extended nor saved in this cell.
all_good_params = [good_params]
all_good_stats = [good_stats]
all_good_seeds = [good_seeds]

total_num_close_to_obs = 0

for k in range(2000):
    data = np.load(f"../../results/11deg_post_pred/11deg_5million_predictives_for_temp/simulated/11deg_5million_predictives_for_temp_{k}.npz")
    params_5million = data['params']
    stats_5million = data['stats']
    seeds_5million = data['seeds']

    # Only the first 15 entries of the prior std are used, matching the 15
    # entries of `num_std`.
    good_params_new, good_dat_new, good_seeds_new = select_ss_close_to_obs(
        params_5million,
        stats_5million,
        seeds_5million,
        observation,
        num_std=num_std,
        stats_std=stats_std[:15],
    )

    # Count the samples that passed the selection, not the size of the raw chunk.
    num_good = len(good_params_new)
    if k % 200 == 0:
        print("Num_of_good: ", num_good)

    total_num_close_to_obs += num_good

    # `good_dat_new` holds the selected summary statistics returned above.
    np.savez(f"../../results/11deg_post_pred/11deg_5million_predictives_for_temp/close_to_obs/11deg_5million_predictives_for_temp_{k}.npz", params=good_params_new, stats=good_dat_new, seeds=good_seeds_new)

print("total number of remaining samples:   ", total_num_close_to_obs)

### Simulate all data at 27 degrees

In [10]:
# NOTE(review): machine-specific absolute path; this cell only runs where the
# pyloric checkout lives at exactly this location.
setups_path = '/home/michael/Documents/pyloric/pyloric/setups.prm'
setups_dict = ParameterSet(setups_path)

# Hyperparameters of the chosen setup and the prior built from them.
hyperparams_11 = setups_dict['collect_samples_15deg_energy_ssRanges']
general_prior = create_prior_general(hyperparams_11, log=True)

In [13]:
# Override the model temperature used by every simulation that is run with
# `hyperparams_11` from here on.
# NOTE(review): 299 is presumably in Kelvin and meant to be the "27 degree"
# condition named in the section heading — TODO confirm the units.
hyperparams_11.model_params.temp = 299

In [None]:
# Re-simulate every close-to-observation sample with the temperature currently
# set in `hyperparams_11` (changed to 299 in the preceding cell) and store the
# resulting summary statistics next to the 11-degree ones.


def simulator(params_set):
    """Simulate one parameter set and return its summary statistics.

    Args:
        params_set: 1-D array whose last entry is the simulation seed; all
            entries before it are passed to `simulate_general` as parameters.

    Returns:
        Whatever `stats` computes from the simulation output.
    """
    out_target = simulate_general(
        deepcopy(params_set[:-1].astype(np.float64)),
        hyperparams=hyperparams_11,
        seed=int(params_set[-1]),
    )
    return stats(out_target)


for k in range(2000):
    start_time = time.time()

    data = np.load(f"../../results/11deg_post_pred/11deg_5million_predictives_for_temp/close_to_obs/11deg_5million_predictives_for_temp_{k}.npz")
    params_close = data['params']
    stats_close = data['stats']

    # Draw one prior sample per parameter set and keep its last 10 columns
    # (presumably the Q10 values — verify against `create_prior_general`).
    q10s = general_prior.sample((params_close.shape[0],))[:, -10:].detach().numpy()
    # numpy's keyword is `axis`; `dim` is the torch spelling and raises a TypeError.
    params_with_q10s = np.concatenate((params_close, q10s), axis=1)

    # One seed per sample, appended as the last column so that a single row
    # carries everything `simulator` needs.
    # NOTE(review): the global numpy RNG is not seeded here, so the seeds differ
    # between runs; they are saved below.
    seeds_27 = np.random.randint(0, 10000, (params_with_q10s.shape[0], 1))
    params_with_seeds = np.concatenate((params_with_q10s, seeds_27), axis=1)

    simulation_outputs = Parallel(n_jobs=12)(
        delayed(simulator)(batch)
        for batch in params_with_seeds
    )

    simulation_outputs = np.asarray(simulation_outputs)

    # Save the simulated parameters (including the sampled columns) under
    # `params`; the simulation outputs belong under `stats_27` only.
    np.savez(f"../../results/11deg_post_pred/temperature/sbi/201005_5Million_post_pred_close_to_obs_simulated_at_27/set_{k}.npz", params=params_with_q10s, stats_27=simulation_outputs, stats_11=stats_close, seeds=seeds_27)
    print("Overall time for iteration", k, ":   ", time.time()-start_time)

### Select only those simulations that are robust at 27 degrees
Why not also check the Q10 of the cycle frequency? Because the Q10 of tau is fixed anyway, so they will increase in speed regardless.

In [None]:
# Keep only the samples whose 27-degree simulation produced no NaN in any
# summary statistic, merge the survivors of all 2000 chunks and save them in a
# single file.
all_params = []
all_stats_27 = []
all_stats_11 = []  # collected per chunk, but not written to the output file
all_seeds = []

for k in range(2000):
    data = np.load(f"../../results/11deg_post_pred/temperature/sbi/201005_5Million_post_pred_close_to_obs_simulated_at_27/set_{k}.npz")
    params_close = data['params']
    stats_27 = data['stats_27']
    stats_11 = data['stats_11']
    seeds_close = data['seeds']

    # A chunk without samples has nothing to contribute, and its stats array
    # may not be 2-D, which would break the row-wise reduction below.
    if stats_27.shape[0] == 0:
        continue

    # True for rows whose 27-degree statistics are all defined. The mask is
    # built from this chunk's `stats_27`, not from a variable left over from
    # another cell; numpy's keyword is `axis`, not `dim`.
    condition = np.invert(np.any(np.isnan(stats_27), axis=1))

    all_params.append(params_close[condition])
    all_stats_27.append(stats_27[condition])
    all_stats_11.append(stats_11[condition])
    all_seeds.append(seeds_close[condition])

all_params = np.concatenate(all_params)
all_stats = np.concatenate(all_stats_27)
all_seeds = np.concatenate(all_seeds)

print("Total number of robust samples:  ", all_params.shape[0])
np.savez("../../results/11deg_post_pred/temperature/sbi/201005_5Million_post_pred_close_to_obs_simulated_at_27_and_robust.npz", params=all_params, stats=all_stats, seeds=all_seeds)