# Notebook to identify start and end points
This notebook finds parameter sets that are far apart in parameter space but produce similar activity.  
The final selection relies partly on visual inspection (similarity of the simulated voltage traces).

In the end, it produces 8 pairs of parameter sets. We used the fourth one (index 3 when counting from 0). This parameter set is saved in results/31D_pairs/similar_and_good/sample_pair_3.npz

In [None]:
import numpy as np
import matplotlib.pylab as plt
import delfi.distribution as dd
import time
from copy import deepcopy
import sys
sys.path.append("model/setup")
sys.path.append("model/simulator")
sys.path.append("model/inference")
sys.path.append("model/visualization")
sys.path.append("model/utils")
sys.path.append("../")

import netio
import viz
import importlib
import viz_samples_start_end_point as viz_samples_thesis
import train_utils as tu
import startEndUtils as seu
from find_pyloric import merge_samples, params_are_bounded
import dill as pickle
import matplotlib as mpl

from common import col, svg, samples_nd

# Output paths for figure panels: PANEL_A points into illustration/, PANEL_B-E
# into svg/. None of these constants is referenced by the cells in this notebook.
PANEL_A = 'illustration/panel_a.svg'
PANEL_B = 'svg/panel_b.svg'
PANEL_C = 'svg/panel_c.svg'
PANEL_D = 'svg/panel_d.svg'
PANEL_E = 'svg/panel_e.svg'

# IPython autoreload: with mode 2, imported modules (netio, viz, seu, ...) are
# reloaded before each cell runs, so edits to them are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Hyperparameter setup used below to build the prior, the simulators and the
# summary-statistics calculator.
params = netio.load_setup("train_31D_R1_BigPaper")

In [None]:
# Restore the trained SNPE-MAF inference object and its training log; the third
# pickled element is discarded. `pickle` here is dill: unpickling runs
# arbitrary code, so only load result files produced locally.
with open('results/31D_nets/191001_seed1_Exper11deg.pkl', 'rb') as net_file:
    inf_SNPE_MAF, log, _ = pickle.load(net_file)

In [None]:
prior = netio.create_prior(params, log=True)

# Parameter count: the entries of params.use_membrane plus 7 further
# parameters (presumably the synaptic ones -- TODO confirm).
dimensions = np.sum(params.use_membrane) + 7

# (dimensions, 2) array of [low, high] bounds of the normalized box
# [-sqrt(3), sqrt(3)], i.e. a unit-variance uniform range per dimension.
lims = np.column_stack((-np.sqrt(3) * np.ones(dimensions),
                        np.sqrt(3) * np.ones(dimensions)))

# Load data

In [None]:
# Load the simulated training set. Note: the params_mean/params_std returned
# here are overwritten in the next cell by the prior's mean/std.
filedir = "results/31D_samples/pyloricsamples_31D_noNaN_3.npz"
pilot_data, trn_data, params_mean, params_std = tu.load_trn_data_normalize(
    filedir, params
)
print('We use', len(trn_data[0]), 'training samples.')

# Per-statistic mean and std over the pilot runs; stats_std later scales
# summary-statistic differences.
stats = pilot_data[1]
stats_mean, stats_std = np.mean(stats, axis=0), np.std(stats, axis=0)

In [None]:
# Rebuild the prior and take the parameter normalization constants from it,
# replacing the values returned by load_trn_data_normalize.
prior = netio.create_prior(params, log=True)
params_mean, params_std = prior.mean, prior.std

In [None]:
from find_pyloric import merge_samples, params_are_bounded

# Axis labels for the parameters, and a uniform prior on the normalized box
# [-sqrt(3), sqrt(3)] in every dimension.
labels_ = viz.get_labels(params)
prior_normalized = dd.Uniform(
    -np.sqrt(3) * np.ones(dimensions),
    np.sqrt(3) * np.ones(dimensions),
    seed=params.seed,
)

In [None]:
# Experimental summary statistics that the posterior is conditioned on.
# The NpzFile is used as a context manager so its file handle is closed once
# the array has been read into memory (the bare np.load left it open).
with np.load('results/31D_experimental/190807_summstats_prep845_082_0044.npz') as npz:
    summstats_experimental = npz['summ_stats']

In [None]:
from find_pyloric import merge_samples, params_are_bounded

prior_normalized = dd.Uniform(
    -np.sqrt(3) * np.ones(dimensions),
    np.sqrt(3) * np.ones(dimensions),
    seed=params.seed,
)

# Condition the trained density estimator on the experimental summary
# statistics to obtain the posterior. Caveat: the estimator may be overfitted.
target = summstats_experimental
posterior_MAF = inf_SNPE_MAF.predict([target])

In [None]:
# Merge the stored sample chunks (presumably MAF posterior samples, per the
# variable name) into one (n_samples, 31) array. Using -1 lets NumPy infer
# n_samples instead of hard-coding 1000*2520; the result is identical when the
# data has exactly that many samples.
samples_MAF_11 = merge_samples("results/31D_samples/02_cond_vals", name='conductance_params')
samples_MAF_11 = np.reshape(samples_MAF_11, (-1, 31))
print(np.shape(samples_MAF_11))

# Create database of (theta, x)
### Start points

In [None]:
from copy import deepcopy

pyloric_sim = netio.create_simulators(params)
summ_stats = netio.create_summstats(params)

# Summary-statistic differences are scaled by the std of the pilot-run stats.
print('stats_std', stats_std)
exp_stds = stats_std

indizes = range(25)

all_one_params = []
all_one_stats = []

start_time = time.time()
for ind in indizes:
    if ind % 100 == 0:
        print('---- Index:', ind, '----')

    # NOTE(review): the original read from an undefined `samples_MAF`; the only
    # sample array loaded in this notebook is samples_MAF_11. The window test
    # and the `* params_std + params_mean` step below assume these samples are
    # in normalized units -- TODO confirm.
    target = samples_MAF_11[ind]

    # "Start point" window on parameters 24 and 25; bounds are written as
    # offsets above the lower edge -sqrt(3) of the normalized box.
    in_window_24 = 2.1 - np.sqrt(3) < target[24] < 2.5 - np.sqrt(3)
    in_window_25 = 0.35 - np.sqrt(3) < target[25] < 0.75 - np.sqrt(3)
    if not (in_window_24 and in_window_25):
        continue

    # Back to physical units, simulate with a fixed seed, and compute stats.
    target_params = target * params_std + params_mean
    out_target = pyloric_sim[0].gen_single(deepcopy(target_params), seed_sim=True, to_seed=1)
    ss = summ_stats.calc([out_target])[0]

    # Keep the sample only if each of the first 15 statistics lies within
    # 0.1 std of the experimental value.
    ss_diff = np.abs(summstats_experimental[:15] - ss[:15]) / exp_stds[:15]
    if np.all(ss_diff < 0.1):
        all_one_params.append(target_params)
        all_one_stats.append(ss)
        print('Found 0.1 std diff')

# Saved next to sample_params_end.npz, at the path the pair-search cell loads from.
np.savez('results/31D_pairs/all_similar_to_obs/sample_params_start.npz', params=all_one_params, summ_stats=all_one_stats)
print('Overall time:  ', time.time()-start_time)

### End points

In [None]:
import importlib
from copy import deepcopy
import matplotlib.gridspec as gridspec

pyloric_sim = netio.create_simulators(params)
summ_stats = netio.create_summstats(params)

# Summary-statistic differences are scaled by the std of the pilot-run stats.
print('stats_std', stats_std)
exp_stds = stats_std

indizes = range(50)

all_one_params = []
all_one_stats = []

start_time = time.time()
for ind in indizes:
    if ind % 1000 == 0:
        print('---- Index:', ind, '----')

    # NOTE(review): the original read from an undefined `samples_MAF`; the only
    # sample array loaded in this notebook is samples_MAF_11. The window test
    # and the `* params_std + params_mean` step below assume these samples are
    # in normalized units -- TODO confirm.
    target = samples_MAF_11[ind]

    # "End point" window on parameters 24 and 25; bounds are written as
    # offsets above the lower edge -sqrt(3) of the normalized box.
    in_window_24 = 0.9 - np.sqrt(3) < target[24] < 1.3 - np.sqrt(3)
    in_window_25 = 1.75 - np.sqrt(3) < target[25] < 2.18 - np.sqrt(3)
    if not (in_window_24 and in_window_25):
        continue

    # Back to physical units, simulate with a fixed seed, and compute stats.
    target_params = target * params_std + params_mean
    out_target = pyloric_sim[0].gen_single(deepcopy(target_params), seed_sim=True, to_seed=1)
    ss = summ_stats.calc([out_target])[0]

    # Keep the sample only if each of the first 15 statistics lies within
    # 0.1 std of the experimental value.
    ss_diff = np.abs(summstats_experimental[:15] - ss[:15]) / exp_stds[:15]
    if np.all(ss_diff < 0.1):
        all_one_params.append(target_params)
        all_one_stats.append(ss)
        print('Found 0.1 std diff')

np.savez('results/31D_pairs/all_similar_to_obs/sample_params_end.npz', params=all_one_params, summ_stats=all_one_stats)
print('Overall time:  ', time.time()-start_time)

### Search for samples with similar activity but disparate parameters

In [None]:
# Load the start- and end-point candidates (physical-unit parameters and their
# summary statistics). Each NpzFile is closed after its arrays are read.
with np.load('results/31D_pairs/all_similar_to_obs/sample_params_start.npz') as npz:
    sample_params_start = npz['params']
    sample_stats_start = npz['summ_stats']
with np.load('results/31D_pairs/all_similar_to_obs/sample_params_end.npz') as npz:
    sample_params_end = npz['params']
    sample_stats_end = npz['summ_stats']

In [None]:
# Despite its name, `unnorm_s` holds the start-point parameters in NORMALIZED
# units: this is the inverse of `x * params_std + params_mean`.
unnorm_s = np.true_divide(sample_params_start - params_mean, params_std)

In [None]:
start_baselines = 0
number_baselines = len(sample_stats_start)
number_comparisons = len(sample_stats_end)

# Accumulators updated by seu.check_num_different_conds. The all_num_diff*
# lists are initialised here but are not filled anywhere in this notebook.
all_index1, all_index2 = [], []
all_num_diff, all_num_diff_membrane, all_num_diff_syn = [], [], []

margin = 0.1

# All-pairs comparison: every start candidate against every end candidate.
for base_idx in range(start_baselines, start_baselines + number_baselines):
    # Copies presumably protect the loaded arrays from in-place changes by
    # the seu helpers.
    base_stats = deepcopy(sample_stats_start[base_idx])
    base_params = deepcopy(sample_params_start[base_idx])

    for cand_idx in range(number_comparisons):
        cand_stats = sample_stats_end[cand_idx]
        cand_params = sample_params_end[cand_idx]

        # Presumably tests whether the statistics agree within `margin`
        # (scaled by stats_std) -- see startEndUtils.check_equality.
        if not seu.check_equality(base_stats, cand_stats, margin=margin, stats_std=stats_std, mode='dataset'):
            continue
        all_index1, all_index2 = seu.check_num_different_conds(base_params, cand_params, all_index1, all_index2)

outfile = 'results/31D_pairs/similar_to_each_other/sample_pair'
np.savez_compressed(outfile, params1=all_index1, params2=all_index2)
print('--- Finished successfully ---')

### Load and display the similar pairs

In [None]:
# Load the matched pairs written by the pair search (np.savez_compressed
# appended the .npz suffix). The NpzFile is closed once both arrays are read.
with np.load('results/31D_pairs/similar_to_each_other/sample_pair.npz') as npz:
    index1 = npz['params1']
    index2 = npz['params2']
print(len(index1))
params = netio.load_setup('train_31D_R1_BigPaper')

In [None]:
counter = 0

# Hand-picked pair indices, chosen by visual inspection of the traces
# (iterate over range(len(index1)) to review every pair instead).
selected_pairs = [2, 3, 6, 17, 21, 24, 27, 28]

for pair_num in selected_pairs:
    print('Novel pair', pair_num)
    params1 = index1[pair_num]
    params2 = index2[pair_num]

    # Simulate and plot both members with identical settings (same seed and
    # plotting options) so their traces can be compared side by side. This
    # replaces two verbatim copies of the same simulate/plot code.
    for target_params in (params1, params2):
        out_target = pyloric_sim[0].gen_single(deepcopy(target_params), seed_sim=True, to_seed=418011)

        fig = viz_samples_thesis.vis_sample(pyloric_sim[0], summ_stats, target_params, voltage_trace=out_target, test_idx=[0], case='high_p', hyperparams=params, scale_bar=False, vis_legend=False, offset_labels=1000, with_ss=False, time_len=165000, fontscale=1.2, linescale=1.2, legend=False, offset=20000,
                                     mode='31D', mem_dimensions=[0,1,8,14,19,21], title='Sample along the path of high probability in Prinz format', date_today='190705_posterior_samples_experimental', multiplier_cond_shift=80, mode2='small', counter=0, save_fig=False)
        plt.show()

    # Store the pair in normalized units, numbered consecutively by `counter`
    # (0..7), not by pair_num.
    outfile = 'results/31D_pairs/similar_and_good/sample_pair_{}'.format(counter)
    np.savez_compressed(outfile, params1=(params1-params_mean)/params_std, params2=(params2-params_mean)/params_std)

    counter += 1