In [None]:
from matplotlib import pyplot as plt
import numpy as np
from scipy.stats import entropy
# %matplotlib inline
# %matplotlib notebook

In [None]:
# Global matplotlib defaults for every figure in this notebook:
# 14 pt text on a 20 x 10 inch canvas.
plt.rcParams.update({
    'font.size': 14,
    'figure.figsize': (20, 10),
})

# Harmonic Oscillator

In [None]:
# Parameter values for the harmonic-oscillator study.
# NOTE(review): presumably the settings covered by the HO grid-search
# savefile loaded further down -- confirm against the script that wrote it.
num_observations = np.array((1, 2, 5, 10, 25, 50))
true_observational_variance = np.array((0.0125, 0.025, 0.05))  # 0.1/6 is true
start_time = np.array((0.25, 0.5, 1.0))
obs_len = np.arange(1, 5)

In [None]:
# Load all runs of the harmonic-oscillator grid search.
# NOTE(review): the records are indexed with string keys ('p', 't') below, so
# the file is presumably a pickled object array of per-run records. NumPy
# refuses to unpickle by default, hence allow_pickle=True. Unpickling can
# execute arbitrary code -- only load savefiles you generated yourself.
Dg = np.load("savefiles/HO_savefile_gridsearch_allruns_fine.npy", allow_pickle=True)
# Make the plot
def grid(nbins=5, mins=np.zeros(1), maxs=np.ones(1) ):
    """Return a regular tensor-product grid of points inside a box.

    Parameters
    ----------
    nbins : int
        Number of points along every dimension.
    mins, maxs : array_like
        Lower / upper corner of the box, one entry per dimension. Arrays,
        lists, tuples and (for a 1-D box) plain scalars are accepted.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(nbins**dim, dim)``; each row is one grid point and
        column ``k`` holds coordinate ``k``. In two dimensions the first
        coordinate varies fastest.

    Notes
    -----
    Along each dimension the points run from ``min + 1/(2*nbins)`` to
    ``max - 1/(2*nbins)``. The inset is an absolute ``1/(2*nbins)``, not a
    fraction of the interval width, so the points are exact bin centres only
    for unit-width intervals.
    """
    # Coerce the bounds so that lists/tuples/scalars work as well as arrays.
    lower = np.atleast_1d(np.asarray(mins, dtype=float))
    upper = np.atleast_1d(np.asarray(maxs, dtype=float))
    dim = len(lower)
    inset = 1./(2*nbins)
    axes = [np.linspace(lo, hi, nbins) for lo, hi in zip(lower + inset, upper - inset)]
    # Stack the per-dimension coordinate arrays, then flatten to one row per point.
    S = np.array(np.meshgrid(*axes))
    return S.ravel().reshape(dim, nbins**dim).T

# Evaluation grid: nbins x nbins points over the 2-D parameter box.
nbins = 250
mins = np.array([-0.25, .5]) # original domain
maxs = np.array([0.25, 1.5])
X = grid(nbins, mins, maxs)

# Settings of the run to display; a run is identified by its 't' entry,
# compared below against (num_obs, error, start_time, observation_len).
num_obs = 25
error = 0.025
start_time = 1
observation_len = 4

# Alternative setting:
# num_obs = 5
# error = 0.0125
# start_time = 1
# observation_len = 1

target_settings = np.array([num_obs, error, start_time, observation_len], dtype=float)

# Coordinate arrays shaped like the image passed to pcolormesh.
xi = X[:,0].reshape(nbins, nbins)
yi = X[:,1].reshape(nbins, nbins)

found_run = False
for run in Dg:
    # Tolerance-based comparison (np.allclose defaults) instead of exact
    # float equality, so storage round-off cannot hide the requested run.
    if not np.allclose(np.asarray(run['t'], dtype=float), target_settings):
        continue
    found_run = True
    post_eval = run['p']
    zi = post_eval.reshape(nbins, nbins)
    plt.cla()
    plt.pcolormesh(xi, yi, zi)
    # White marker at (0, 1) -- presumably the reference parameter value.
    plt.scatter([0], [1], c='white', edgecolor='black', s=100)
    plt.show()

if not found_run:
    # Make an empty result explicit instead of silently drawing nothing.
    print('No run in Dg matches (num_obs, error, start_time, observation_len) = %s'
          % (target_settings.tolist(),))

In [None]:
# Write pyplot's current figure to saveimgs/ho.eps.
# NOTE(review): this only captures the plot from the previous cell if that
# figure is still the current one (depends on the active matplotlib backend;
# the %matplotlib magics at the top are commented out) -- confirm the saved
# file is not blank. The saveimgs/ directory must already exist.
plt.savefig('saveimgs/ho.eps')

# Exponential Decay

In [None]:
# Parameter values for the exponential-decay study.
# NOTE(review): presumably the settings covered by the ED grid-search
# savefile loaded further down -- confirm against the script that wrote it.
num_observations = np.array((1, 2, 5, 10, 25, 50))
true_observational_variance = np.array((0.0125, 0.025, 0.05))
start_time = np.linspace(0.25, 1.0, 4)  # 0.25, 0.5, 0.75, 1.0
obs_len = np.arange(1, 5)

In [None]:
# Load all runs of the exponential-decay grid search.
# NOTE(review): the records are indexed with string keys ('p', 't') below, so
# the file is presumably a pickled object array of per-run records. NumPy
# refuses to unpickle by default, hence allow_pickle=True. Unpickling can
# execute arbitrary code -- only load savefiles you generated yourself.
Dg = np.load("savefiles/ED_savefile_gridsearch_allruns_fine.npy", allow_pickle=True)
# Make the plot
def grid(nbins=5, mins=np.zeros(1), maxs=np.ones(1) ):
    """Build a regular grid of ``nbins`` points per dimension inside a box.

    Parameters
    ----------
    nbins : int
        Points per dimension.
    mins, maxs : array_like
        Per-dimension lower / upper bounds of the box. Any array-like is
        accepted (ndarray, list, tuple); a scalar is treated as a 1-D box.

    Returns
    -------
    numpy.ndarray
        ``(nbins**dim, dim)`` array with one grid point per row. In two
        dimensions the first coordinate varies fastest down the rows.

    Notes
    -----
    Each axis spans ``[min + 1/(2*nbins), max - 1/(2*nbins)]``. The offset is
    a fixed ``1/(2*nbins)`` independent of the interval width, so the points
    coincide with bin centres only when the interval has unit width.
    """
    # Normalise the bounds first; plain lists would otherwise fail on `+ float`.
    box_min = np.atleast_1d(np.asarray(mins, dtype=float))
    box_max = np.atleast_1d(np.asarray(maxs, dtype=float))
    dim = len(box_min)
    half_bin = 1./(2*nbins)
    per_axis = [
        np.linspace(start, stop, nbins)
        for start, stop in zip(box_min + half_bin, box_max - half_bin)
    ]
    # meshgrid gives one coordinate array per dimension; flatten to rows of points.
    S = np.array(np.meshgrid(*per_axis))
    SS = S.ravel().reshape(dim, nbins**dim).T
    return SS

# Evaluation grid: nbins x nbins points over the 2-D parameter box.
nbins = 250
mins = np.array([0, 0]) # original domain
maxs = np.array([1, 1])
X = grid(nbins, mins, maxs)

# Settings of the run to display; a run is identified by its 't' entry,
# compared below against (num_obs, error, start_time, observation_len).
num_obs = 50
error = 0.025
start_time = 0.5
observation_len = 4

# Alternative setting:
# num_obs = 5
# error = 0.0125
# start_time = 1
# observation_len = 1

target_settings = np.array([num_obs, error, start_time, observation_len], dtype=float)

# Coordinate arrays shaped like the image passed to pcolormesh.
xi = X[:,0].reshape(nbins, nbins)
yi = X[:,1].reshape(nbins, nbins)

found_run = False
for run in Dg:
    # Tolerance-based comparison (np.allclose defaults) instead of exact
    # float equality, so storage round-off cannot hide the requested run.
    if not np.allclose(np.asarray(run['t'], dtype=float), target_settings):
        continue
    found_run = True
    post_eval = run['p']
    zi = post_eval.reshape(nbins, nbins)
    plt.cla()
    plt.pcolormesh(xi, yi, zi)
    # White marker at (0.5, 0.5) -- presumably the reference parameter value.
    plt.scatter([0.5], [0.5], c='white', edgecolor='black', s=100)
    plt.show()

if not found_run:
    # Make an empty result explicit instead of silently drawing nothing.
    print('No run in Dg matches (num_obs, error, start_time, observation_len) = %s'
          % (target_settings.tolist(),))

In [None]:
# Write pyplot's current figure to saveimgs/ed.eps.
# NOTE(review): this only captures the plot from the previous cell if that
# figure is still the current one (depends on the active matplotlib backend;
# the %matplotlib magics at the top are commented out) -- confirm the saved
# file is not blank. The saveimgs/ directory must already exist.
plt.savefig('saveimgs/ed.eps')

## Error and KL-Divergence Statistics


In [None]:
# Collect three summary tables from the fixed-window grid search:
#   T0 -- [num_obs, entropy/KL, sample_err_1, sample_err_2] with both sds and max_time fixed
#   T  -- [sd_true, sd_guess, pf_rel_err_mean, pf_rel_err_std] with num_obs and max_time fixed
#   E  -- [max_time, entropy/KL, sample_err_1, sample_err_2] with num_obs and both sds fixed
#
# NOTE(review): `window_length` (and `perceived_observational_variance`, used
# when plotting T) are not defined anywhere in the visible notebook, and
# `num_observations` comes from an earlier section. Confirm they are defined
# before this cell, otherwise it fails on a fresh kernel.
# NOTE(review): records are indexed with string keys, so the file is presumably
# a pickled object array; allow_pickle=True is required by NumPy for that.
# Unpickling can execute arbitrary code -- only load files you generated.
Dg = np.load("savefile_gridsearch_fix_window.npy", allow_pickle=True)
E = []
T = []
T0 = []
fix_time = int(window_length[0])  # int() truncates a non-integer first window length
fix_obs = int(num_observations[1])
fix_sd_true = 0.0125
fix_sd_guess = fix_sd_true

# Each sweep keeps its OWN reference distribution (the previous distribution
# in that sweep). Sharing a single variable lets one sweep overwrite the
# other's reference whenever their records interleave.
prev_obs_post = None
prev_time_post = None

for D in Dg:
    n = int(D['n'])  # not used further in this cell
    sample_err_1 = D['a'][0]
    sample_err_2 = D['a'][1]
    pf_rel_err_mean = D['r'][0]
    pf_rel_err_std = D['r'][1]
    num_obs, sd_true, sd_guess, max_time = D['t']
    posterior = D['p']
    sd_is_fixed = sd_true == fix_sd_true and sd_guess == fix_sd_guess

    # Sweep over the number of observations.
    if sd_is_fixed and max_time == fix_time:
        if num_obs == num_observations[0]:
            # First point of the sweep: no reference, so entropy() returns
            # the plain entropy instead of a KL divergence.
            prev_obs_post = None
        T0.append([ num_obs, entropy(posterior, prev_obs_post), sample_err_1, sample_err_2 ])
#         print(entropy(D['p'], ref))
#         print(pf_rel_err_mean)
        prev_obs_post = posterior

    # Sweep over (true sd, guessed sd).
    if num_obs == fix_obs and max_time == fix_time:
        T.append([sd_true, sd_guess, pf_rel_err_mean, pf_rel_err_std])

    # Sweep over the observation window length.
    if num_obs == fix_obs and sd_is_fixed:
        if max_time == window_length[0]:
            prev_time_post = None
        E.append([ max_time, entropy(posterior, prev_time_post), sample_err_1, sample_err_2 ])
        prev_time_post = posterior

E = np.array(E)
T = np.array(T)
T0 = np.array(T0)
# Discard any open figures and lay out a fresh 1 x 3 panel figure.
plt.close('all')
plt.subplots(1, 3)

# Left panel: colour map of column 3 of T over the (column 0, column 1) plane.
# NOTE(review): column 3 is the entry appended as pf_rel_err_std where T is
# built; the title only says 'error' -- confirm std (not mean) is intended.
plt.subplot(131)
sd_grid_shape = (len(true_observational_variance), len(perceived_observational_variance))
xi, yi, zi = (T[:, column].reshape(sd_grid_shape) for column in (0, 1, 3))
plt.pcolormesh(xi, yi, zi)
plt.xscale('log')
plt.yscale('log')
plt.xlabel('true error model')
plt.ylabel('guess at error model')
plt.title('error, fix_obs = %d, fix_time = %d' % (fix_obs, fix_time))
# Saved while only this first panel has been drawn.
plt.savefig('testfig.png')

# Middle panel: the T0 columns plotted against column 0 (number of observations).
plt.subplot(132)
plt.cla()
print(np.max(T0[:,2]))
for column, curve_label in ((2, 'err in lam_1'), (3, 'err in lam_2'), (1, 'KL divergence')):
    plt.plot(T0[:,0], T0[:,column], label=curve_label)
# Black reference curve 0.1 / num_observations.
plt.plot(num_observations, .1/num_observations, label='MC Rate', c='k')

plt.title('time = %d' % fix_time)
plt.xlabel('num obs')
plt.ylabel('error')
plt.xscale('log')
plt.yscale('log')
plt.ylim([1E-10, 10])
plt.legend(fontsize=12, framealpha=0.5, loc='upper right')

# Right panel: the E columns plotted against column 0 (max time).
plt.subplot(133)
plt.cla()
for column, curve_label in ((2, 'err in lam_1'), (3, 'err in lam_2'), (1, 'KL divergence')):
    plt.plot(E[:,0], E[:,column], label=curve_label)
# Black reference curve 0.1 / window_length.
plt.plot(window_length, .1/window_length, label='MC Rate', c='k')

plt.title('fix_obs = %d, fix_sd = (%.2e, %.2e)' % (fix_obs, fix_sd_true, fix_sd_guess))
plt.xlabel('max time')
plt.ylabel('error')
plt.xscale('log')
plt.yscale('log')
plt.ylim([1E-8, 10])
plt.legend(fontsize=12, framealpha=0.5, loc='upper right')

plt.show()


In [None]:
# Stand-alone redraw of the T0 curves on the current axes, with the
# lower y-limit at 1e-8.
plt.cla()
for column, curve_label in ((2, 'err in lam_1'), (3, 'err in lam_2'), (1, 'KL divergence')):
    plt.plot(T0[:,0], T0[:,column], label=curve_label)
# Black reference curve 0.1 / num_observations.
plt.plot(num_observations, .1/num_observations, label='MC Rate', c='k')

plt.title('time = %d' % fix_time)
plt.xlabel('num obs')
plt.ylabel('error')
plt.xscale('log')
plt.yscale('log')
plt.ylim([1E-8, 10])
plt.legend(fontsize=12, framealpha=0.5, loc='upper right')