In [1]:
import choice_probabilities_mcmc as cpmcmc
import numpy as np
import seaborn as sns
import scipy.stats as scps
import matplotlib.pyplot as plt
import pandas as pd
import statsmodels as statm
from datetime import datetime
from statsmodels.tsa.stattools import acf
import os
import csv
import glob

In [None]:
# Create the sampler object (from the project-local `choice_probabilities_mcmc` module)
# that the following cells configure and run.
cl = cpmcmc.choice_probabilities_analytic_mh()

In [None]:
# Model parameters: select which trained network the sampler should use.
# NOTE(review): presumably make_checkpoint_path() / make_model_path() combine these
# fields into the on-disk location of the model — confirm in cpmcmc.
cl.model_num = 0
cl.model_time = '09_03_18_17_28_21'  # looks like a timestamp tag of a training run — TODO confirm format
cl.model_signature = '_choice_probabilities_analytic_'
cl.model_checkpoint = 'final'

In [None]:
# Data simulation parameters: the (v, a) pair and the number of simulated samples
# for the data set built in the "Make dataset" cell.
# NOTE(review): presumably v is the drift and a the boundary separation of the
# underlying model — confirm against cpmcmc.
cl.data_sim_params['v'] = 0.0
cl.data_sim_params['a'] = 2
cl.data_sim_params['n_samples'] = 5000

In [None]:
# Make paths
cl.make_checkpoint_path()
cl.make_model_path()
# Last expression of the cell, so the notebook displays the model path as a sanity check.
cl.model_path

In [None]:
# Make dataset
# Later cells read the result from `cl.data_sim` ('n_choice_lower', 'n_samples').
cl.make_data_set()

In [None]:
# Get predictor
# Attaches the DNN to the sampler; must run before sampling with method = 'dnn'.
# NOTE(review): presumably loads the Keras model from cl.model_path — confirm in cpmcmc.
cl.get_dnn_keras()

In [None]:
# Sampler settings for the DNN-based run: chain length and the 2x2 diagonal
# `cov_init` matrix (0.1 on the diagonal, 0 elsewhere).
cl.mcmc_params['n_samples'] = 100000
cl.mcmc_params['cov_init'] = np.diag([0.1, 0.1])

# Keyword arguments for the Metropolis-Hastings run that uses the DNN ('dnn') method.
dnn_run_options = dict(method = 'dnn',
                       variance_scale_param = 0.4,
                       variance_epsilon = 0.05,
                       write_to_file = True,
                       print_steps = False)

# `my_chain` is inspected again further down the notebook; `acc_samples` is the second
# value returned by the sampler.
my_chain, acc_samples = cl.metropolis_hastings(**dnn_run_options)

In [None]:
# Row of the chain with the highest log posterior (the MAP sample of this run).
cl.chain.loc[cl.chain['log_posterior'].idxmax()]

In [None]:
# Get autocorrelation
# Effective sample size of the 'a' chain: ESS = N / (1 + 2 * sum_{k >= 1} rho_k).
# acf() returns the autocorrelations for lags 0..nlags; the lag-0 entry is always 1
# and is not part of the formula, so it is skipped with [1:].
chain_autocorrelations = acf(cl.chain['a'], nlags = 80)
n_eff_samples = cl.mcmc_params['n_samples'] / (1 + 2 * np.sum(chain_autocorrelations[1:]))

# N effective samples
n_eff_samples

In [None]:
def v_a_curve(x = 0.5, sign = 1):
    """Evaluate the curve a = log((1 - x) / x) / v on a grid of v values.

    Parameters
    ----------
    x : float
        Value strictly between 0 and 1 that enters through log((1 - x) / x).
        The notebook passes the observed fraction of lower-boundary choices.
    sign : int
        1 or -1. The grid starts at 0.01 * sign and moves towards 10 * sign
        (exclusive) in steps of 0.01 * sign, so it selects the positive or
        negative half of the v axis.

    Returns
    -------
    pandas.DataFrame
        Columns 'v' and 'a', one row per grid point.
    """
    v_grid = np.arange(0.01 * sign, 10 * sign, 0.01 * sign)
    # The numerator does not depend on v, so compute it once.
    log_ratio = np.log((1 - x) / x)
    # Building the frame from the arrays keeps the row count tied to the grid itself
    # (no hard-coded preallocation) and avoids a slow per-row .loc assignment.
    return pd.DataFrame({'v': v_grid, 'a': log_ratio / v_grid}, columns = ['v', 'a'])

def v_a_curve_prime(x = 0.5, v_star = 1):
    """Evaluate the tangent line of a = log((1 - x) / x) / v at v = v_star.

    With c = log((1 - x) / x), the line is a = -(c / v_star**2) * v + 2 * c / v_star.
    It is evaluated on the grid from 0.01 towards 10 (exclusive) in steps of 0.01,
    mirrored to negative v when v_star is negative.

    Parameters
    ----------
    x : float
        Value strictly between 0 and 1 that enters through log((1 - x) / x).
    v_star : float
        Non-zero v value at which the tangent is taken; its sign picks the grid side.

    Returns
    -------
    pandas.DataFrame
        Columns 'v' and 'a', one row per grid point.

    Raises
    ------
    ValueError
        If v_star is 0 (the curve is undefined there and the grid step would be 0).
    """
    if v_star == 0:
        raise ValueError('v_star must be non-zero')

    side = np.sign(v_star)
    v_grid = np.arange(0.01 * side, 10 * side, 0.01 * side)
    # Slope and intercept do not depend on the grid point, so compute them once.
    log_ratio = np.log((1 - x) / x)
    slope = - (log_ratio / np.power(v_star, 2))
    intercept = 2 * log_ratio / v_star
    # Vectorized construction: no hard-coded row count, no per-row .loc assignment.
    return pd.DataFrame({'v': v_grid, 'a': slope * v_grid + intercept}, columns = ['v', 'a'])

In [None]:
# Joint posterior of (v, a) from the DNN-based chain, with the v-a curve overlaid.
chain_nn = cl.chain.copy()
# Index.get_values() was deprecated in pandas 0.25 and removed in 1.0; .values is
# available in every pandas version.
chain_nn['id'] = chain_nn.index.values

curve = v_a_curve(x = cl.data_sim['n_choice_lower'] / cl.data_sim['n_samples'],
                  sign = -1)
# Keep only the part of the curve that lies strictly inside the range of sampled values.
inside_chain_range = ((curve['a'] < chain_nn['a'].max())
                      & (curve['a'] > chain_nn['a'].min())
                      & (curve['v'] < chain_nn['v'].max())
                      & (curve['v'] > chain_nn['v'].min()))
curve = curve.loc[inside_chain_range].copy()


#curve_prime = v_a_curve_prime(x = cl.data_sim['n_choice_lower'] / cl.data_sim['n_samples'], v_star = 1)
#curve_prime = curve_prime.loc[curve_prime['a'] < np.max(chain_nn['a'])].copy()
#curve_prime = curve_prime.loc[curve_prime['v'] < np.max(chain_nn['v'])].copy()

# x / y are passed by keyword: newer seaborn releases no longer accept them positionally.
g = sns.jointplot(x = 'v', y = 'a', data = chain_nn, kind = 'kde', space = 0, color = 'g')
g.ax_joint.plot(curve['v'], curve['a'], 'r-')
#g.ax_joint.plot(curve_prime['v'], curve_prime['a'], 'b-')
plt.show()

In [None]:
# trace plot v
# 'id' holds the chain's index values, so this draws v against the sample index
# for the DNN-based chain.
ax = sns.lineplot(x = 'id', y = 'v', data = chain_nn)

In [None]:
# trace plot a
# 'id' holds the chain's index values, so this draws a against the sample index
# for the DNN-based chain.
ax = sns.lineplot(x = 'id', y = 'a', data = chain_nn)

In [None]:
# Now sample from actual model
# Priors: normal(0, 10) on v, uniform on [0, 10] on a.
cl.priors = dict(v = scps.norm(loc = 0, scale = 10),
                 a = scps.uniform(loc = 0, scale = 10))

# Shorter chain than the DNN run; same diagonal `cov_init` matrix.
# Note that this overwrites the 'n_samples' setting used by the DNN run above.
cl.mcmc_params['n_samples'] = 50000
cl.mcmc_params['cov_init'] = np.diag([0.1, 0.1])

# Run the sampler with the 'wfpt' method; the result is also read back from `cl.chain` below.
cl.metropolis_hastings(method = 'wfpt')

In [None]:
# Joint posterior of (v, a) from the 'wfpt' chain, with the v-a curve overlaid.
chain_wfpt = cl.chain.copy()
# Index.get_values() was deprecated in pandas 0.25 and removed in 1.0; .values is
# available in every pandas version.
chain_wfpt['id'] = chain_wfpt.index.values

curve = v_a_curve(x = cl.data_sim['n_choice_lower'] / cl.data_sim['n_samples'])
# Drop the part of the curve above the largest sampled a / v.
# NOTE(review): unlike the DNN plot cell, no lower bounds are applied here and the
# default sign = 1 is used — confirm this asymmetry is intended.
below_chain_max = ((curve['a'] < chain_wfpt['a'].max())
                   & (curve['v'] < chain_wfpt['v'].max()))
curve = curve.loc[below_chain_max].copy()

# Draw the posterior plot with the curve overlaid.
# x / y are passed by keyword: newer seaborn releases no longer accept them positionally.
g = sns.jointplot(x = 'v', y = 'a', data = chain_wfpt, kind = 'kde', space = 0, color = 'g')
g.ax_joint.plot(curve['v'], curve['a'], 'r-')
plt.show()

In [None]:
# (rows, columns) of the chain currently stored on the sampler — quick sanity check.
cl.chain.shape

In [None]:
# trace plot v
# 'id' holds the chain's index values, so this draws v against the sample index
# for the 'wfpt' chain.
ax = sns.lineplot(x = 'id', y = 'v', data = chain_wfpt)

In [None]:
# trace plot a
# 'id' holds the chain's index values, so this draws a against the sample index
# for the 'wfpt' chain.
ax = sns.lineplot(x = 'id', y = 'a', data = chain_wfpt)

In [None]:
# Row of the chain with the highest log posterior (the MAP sample of the current run).
# The stray trailing `[]` was an empty subscript, which is a SyntaxError.
cl.chain.loc[cl.chain['log_posterior'].idxmax()]

In [None]:
# Inspect the row labelled 3 of `my_chain`, the chain returned by the DNN-based run above.
my_chain.loc[3]

In [2]:
# Run experiment: Parameter recovery with MAP for DNN vs. NF_Likelihood
# This cell prepares a second, independent sampler (`cl2`) for the experiment loop,
# using the same model settings as `cl` above.

# Make sampler instance
cl2 = cpmcmc.choice_probabilities_analytic_mh()

# Model parameters
cl2.model_num = 0
cl2.model_time = '09_03_18_17_28_21'
cl2.model_signature = '_choice_probabilities_analytic_'
cl2.model_checkpoint = 'final'

# Make paths
cl2.make_checkpoint_path()
cl2.make_model_path()
# Bare expression: it is not the last line of the cell, so nothing is displayed here.
cl2.model_path

# Attach DNN
cl2.get_dnn_keras()

In [3]:
# Main experiment
#
# For each of `n_experiments` randomly drawn (v, a) pairs: simulate a data set, run the
# sampler once with method = 'dnn' and once with method = 'wfpt', then store the effective
# sample size and the MAP row of both chains. One summary row per experiment is appended
# to exp_data.csv; both chains and the simulated data are saved next to it.


def effective_sample_size(samples, n_draws, nlags = 80):
    """Estimate the effective sample size of a chain from its autocorrelation.

    Uses ESS = n_draws / (1 + 2 * sum_{k = 1..nlags} rho_k). acf() returns the
    autocorrelations for lags 0..nlags; the lag-0 entry is always 1 and is not part
    of the formula, so it is excluded from the sum.

    Parameters
    ----------
    samples : array-like
        One-dimensional chain of draws for a single parameter.
    n_draws : int
        Number of draws N used in the numerator.
    nlags : int
        Largest lag included in the autocorrelation sum.

    Returns
    -------
    float
        The estimated effective sample size.
    """
    autocorrelations = acf(samples, nlags = nlags)
    return n_draws / (1 + 2 * np.sum(autocorrelations[1:]))


# Open directory for experiment data
cwd = os.getcwd()
exp_dir = os.path.join(cwd,
                       'experiments',
                       'bayesian_comparison_dnn_nf_choice_probability_analytic_'
                       + datetime.now().strftime('%m_%d_%y_%H_%M_%S'))
# makedirs also creates the parent 'experiments' folder when it does not exist yet;
# like mkdir, it still raises if this exact experiment folder already exists.
os.makedirs(exp_dir)
exp_data_path = os.path.join(exp_dir, 'exp_data.csv')

# Main specification of experiment parameters
n_experiments = 120
cl2.data_sim_params['n_samples'] = 7000
cl2.mcmc_params['n_samples'] = 100000

model_types = ['dnn', 'wfpt']  # not used below

# Storage data
# NOTE: the '*_map_loglik' columns receive the chain's 'log_posterior' value at the MAP row.
columns = ['experiment_id',
           'data_v',
           'data_a',
           'dnn_n_eff_samples',
           'dnn_map_loglik',
           'dnn_map_a',
           'dnn_map_v',
           'nf_n_eff_samples',
           'nf_map_loglik',
           'nf_map_a',
           'nf_map_v',
           ]

# Write the header row once. newline = '' is what the csv module asks for; without it,
# blank lines appear between rows on platforms that translate line endings.
with open(exp_data_path, 'w', newline = '') as f:
    writer = csv.writer(f)
    writer.writerow(columns)

# In-memory placeholder with one row per experiment (not filled in below).
exp_data = pd.DataFrame(np.zeros((n_experiments, len(columns))),
                        columns = columns)

# Main Loop
data_id = 0  # not used below
for cnt in range(n_experiments):
    # Sample parameters for simulation
    v_tmp = np.random.uniform(low = -2, high = 2)
    a_tmp = np.random.uniform(low = 0.5, high = 3)

    # Print info:
    print('Experiment: ', cnt)
    print('v: ', v_tmp)
    print('a: ', a_tmp)

    # Data simulation parameters
    cl2.data_sim_params['v'] = v_tmp
    cl2.data_sim_params['a'] = a_tmp

    # Make dataset
    cl2.make_data_set()

    # Run both samplers on the same simulated data set
    chain_dnn, _ = cl2.metropolis_hastings_custom(method = 'dnn',
                                                  variance_scale_param = 0.4,
                                                  variance_epsilon = 0.05,
                                                  write_to_file = True,
                                                  print_steps = False)

    chain_wfpt, _ = cl2.metropolis_hastings_custom(method = 'wfpt',
                                                   variance_scale_param = 0.4,
                                                   variance_epsilon = 0.05,
                                                   write_to_file = True,
                                                   print_steps = False)

    # Get number of effective samples (computed on the 'a' chain)
    n_eff_samples_dnn = effective_sample_size(chain_dnn['a'], cl2.mcmc_params['n_samples'])
    n_eff_samples_wfpt = effective_sample_size(chain_wfpt['a'], cl2.mcmc_params['n_samples'])

    # Compute map: the row with the highest log posterior in each chain
    map_dnn = chain_dnn.loc[chain_dnn['log_posterior'].idxmax()]
    map_wfpt = chain_wfpt.loc[chain_wfpt['log_posterior'].idxmax()]

    # Store data: append one summary row, in the same order as `columns`
    with open(exp_data_path, 'a', newline = '') as f:
        writer = csv.writer(f)
        writer.writerow([cnt,
                         cl2.data_sim_params['v'],
                         cl2.data_sim_params['a'],
                         n_eff_samples_dnn,
                         map_dnn['log_posterior'],
                         map_dnn['a'],
                         map_dnn['v'],
                         n_eff_samples_wfpt,
                         map_wfpt['log_posterior'],
                         map_wfpt['a'],
                         map_wfpt['v']])

    # Keep the full chains and the simulated data for later inspection
    chain_dnn.to_csv(os.path.join(exp_dir, 'chain_dnn_' + str(cnt) + '.csv'))
    chain_wfpt.to_csv(os.path.join(exp_dir, 'chain_wfpt_' + str(cnt) + '.csv'))
    cl2.data_sim['data'].to_csv(os.path.join(exp_dir, 'data_' + str(cnt) + '.csv'))

Experiment:  0
v:  -1.7028345588824223
a:  1.821313525781874
datapoint 0 generated
datapoint 1000 generated
datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label 0 generated
label 1000 generated
label 2000 generated
label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0


  alpha = np.exp(log_alpha)


1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71

75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
Experiment:  6
v:  -0.8813123592335756
a:  1.3034418557145768
datapoint 0 generated
datapoint 1000 generated
datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label 0 generated
label 1000 

50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
Experiment:  11
v:  1.6451438813842532
a:  2.038988414402545
datapoint 0 generated
datapoint 1000 generated
datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label 0 generated
label 1000 generated
label 2000 generated
label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000


datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label 0 generated
label 1000 generated
label 2000 generated
label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
340

38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
Experiment:  22
v:  -0.4988090877915812
a

13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
Experiment:  27
v:  -1.1376286132132356
a:  1.0459269364708779
datapoint 0 generated
datapoint 1000 generated
datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label 0 generated
label 1000 generated
label 2000 generated
label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
1800

85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
Experiment:  32
v:  -0.8019440639557156
a:  0.9917349580842525
datapoint 0 generated
datapoint 1000 generated
datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label 0 generated
label 1000 generated
label 2000 generated
label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
9000

label 6000 generated
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
670

71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
Experiment:  43
v:  -1.618687025354772
a:  1.3745740792439085
datapoint 0 generated
datapoint 1000 generated
datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label

45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
Experiment:  48
v:  1.724563967309658
a:  2.683486417760568
datapoint 0 generated
datapoint 1000 generated
datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label 0 generated
label 1000 generated
label 2000 generated
label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
5

datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label 0 generated
label 1000 generated
label 2000 generated
label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
2

30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000

3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
Experiment:  64
v:  1.0890905110999678
a:  1.4477575208124513
datapoint 0 generated
datapoint 1000 generated
datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label 0 generated
label 1000 generated
label 2000 generated
label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0
1000
2000
3000
4000
5000
6000
7000
8000
9000


76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
Experiment:  69
v:  0.4471364679574461
a:  1.7273516914722102
datapoint 0 generated
datapoint 1000 generated
datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label 0 generated
label 1000 generated
label 2000 generated
label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000

label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000


61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
Experiment:  80
v:  -0.8815198986179866
a:  2.986542447259166
datapoint 0 generated
datapoint 1000 generated
datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 gener

36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
Experiment:  85
v:  1.9409161477715804
a:  2.0816685038968648
datapoint 0 generated
datapoint 1000 generated
datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label 0 generated
label 1000 generated
label 2000 generated
label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000

datapoint 1000 generated
datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label 0 generated
label 1000 generated
label 2000 generated
label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000


26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000

98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
Experiment:  101
v:  -0.23376149843210792
a:  0.8390781105740479
datapoint 0 generated
datapoint 1000 generated
datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label 0 generated
label 1000 generated
label 2000 generated
label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0
1000
2000
3000
400

71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
Experiment:  106
v:  -1.3059961545941956
a:  2.395166000521984
datapoint 0 generated
datapoint 1000 generated
datapoint 2000 generated
datapoint 3000 generated
datapoint 4000 generated
datapoint 5000 generated
datapoint 6000 generated
label 0 generated
label 1000 generated
label 2000 generated
label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
7600

label 2000 generated
label 3000 generated
label 4000 generated
label 5000 generated
label 6000 generated
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
530

56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
Experiment:  117
v:  -0.17772320544287945
a:  2.4328227752661276
datapoint 0 generated
datapoint 1000 generated
datapoint 2000 generated
datapoint 30

In [None]:
# VISUALIZE SAMPLER OUTCOMES:
# Locate the experiment folder to analyse.
cwd = os.getcwd()
exp_dir_pattern = cwd + '/experiments/bayesian_comparison_dnn_nf_choice_probability_analytic_*'
exp_dir_candidates = glob.glob(exp_dir_pattern)
if not exp_dir_candidates:
    raise FileNotFoundError('No experiment directory matches ' + exp_dir_pattern)
# glob returns matches in arbitrary order, so picking element [0] is not reproducible once
# several runs exist; use the most recently modified folder instead.
exp_dir = max(exp_dir_candidates, key = os.path.getmtime)

In [None]:
# Get data
# One row per experiment; the column names are the header written by the experiment loop
# ('experiment_id', 'data_v', 'data_a', 'dnn_*', 'nf_*').
map_data = pd.read_csv(exp_dir + '/exp_data.csv')


In [None]:
# 2-d plots: MAP estimates from the DNN-based sampler against those from the 'wfpt'
# sampler, one panel per parameter. Points on the grey diagonal mean both agree.

# Set the style before the figure exists: rc settings only affect axes created afterwards,
# so calling sns.set() after plt.subplots() styles the plot only on a second run of the cell.
sns.set(style = "white", palette = "muted", color_codes = True, )
f, axes = plt.subplots(1, 2, figsize = (15, 10), sharex = False)

# Panel 0: parameter a
axes[0].plot([-.5, 4], [-0.5, 4], color = 'black', alpha = 0.3)
g1 = sns.scatterplot(x = 'dnn_map_a', y = 'nf_map_a', data = map_data, ax = axes[0], marker = '+', s = 100)
g1.set(xlim = (-0.5, 4), ylim = (-0.5, 4))
axes[0].set_xlabel('MAP-NN: a')
axes[0].set_ylabel('MAP-TRUE: a')

# Panel 1: parameter v, coloured by the v used to simulate the data
axes[1].plot([-3, 3], [-3, 3], color = 'black', alpha = 0.3)
g2 = sns.scatterplot(x = 'dnn_map_v', y = 'nf_map_v', data = map_data, ax = axes[1], marker = '+', s = 100, hue = 'data_v')
g2.set(xlim = (-3, 3), ylim = (-3, 3))
axes[1].set_xlabel('MAP-NN: v')
axes[1].set_ylabel('MAP-TRUE: v')

# Remove the top and right spines of both panels of this figure.
sns.despine(fig = f)

In [None]:
# INVESTIGATE SUSPICIOUS DATA-POINTS
map_data['map_diff_a'] = np.abs(map_data['dnn_map_a'] - map_data['nf_map_a'])
map_data['map_diff_v'] = np.abs(map_data['dnn_map_v'] - map_data['nf_map_v'])

In [None]:
# Experiments ordered by the gap between the two MAP estimates of v, largest first.
map_data.sort_values(by = ['map_diff_v'], ascending = False)

In [None]:
# Conclusions on suspicious v's: TODO — not written up yet.

In [None]:
# Suspicious a's 
map_data['data_dnn_diff_a'] = np.abs(map_data['data_a'] - map_data['dnn_map_a'])
map_data['data_nf_diff_a'] = np.abs(map_data['data_a'] - map_data['nf_map_a'])

In [None]:
# Experiments ordered by how far the DNN-based MAP estimate of a is from the simulated a,
# largest first.
map_data.sort_values(by = ['data_dnn_diff_a'], ascending = False)

In [None]:
# It seems that, in general, a is much more unstable with respect to parameter recovery.
# The problem tends to be exacerbated when v is hovering around 0, but it also occurs with seemingly innocuous parameter
# combinations of v and a ((-2, 2.5) as an example).

# For now we are not concerned about this, because the main issue is to align the behavior of the Neural Network with the
# behavior of the Navarro and Fuss likelihood.