In [8]:
import os,sys,inspect
# Make the two directories above this one importable so that the project
# modules imported below resolve. Each insert goes to position 0, so the
# grandparent directory ends up ahead of the parent directory on sys.path.
this_file = inspect.getfile(inspect.currentframe())
currentdir = os.path.dirname(os.path.abspath(this_file))
parentdir = os.path.dirname(currentdir)
sparentdir = os.path.dirname(parentdir)
for extra_path in (parentdir, sparentdir):
    sys.path.insert(0, extra_path)

import datetime
import numpy as np
import scipy.stats as stats
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from matplotlib import gridspec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import create
import generate
import recall
import simulations
import simsave
import load_curves

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
import matplotlib as mpl
import copy
import pickle
import numpy as np
import pandas as pd
import time
import seaborn as sns


# Scalar settings shared by the cells below.
N, M = 100, 30  # 1st and 3rd arguments of generate.memories_gaussian
p = 0.5         # 2nd argument of generate.memories_gaussian
r = 0.3         # noise level handed to generate.noise when building the recall cue
c = 0.65        # NOTE: rebound by the `for i_c, c in enumerate(set_of_Cs)` sweep below

# Sweep grids: values written into ss[0], values swept for c, and the
# support on which the posterior over the strength is discretized.
set_of_Ss = np.linspace(0.0, 5.0, num=51)
set_of_Cs = np.linspace(0.4, 0.7, num=13)
s_space = np.linspace(-10.0, 15.0, num=10001)

In [9]:
def get_posterior_s(W, x, M, e_s2, frac, prior, s_space):
    """Discretized, unnormalized posterior over the strength of pattern ``x``.

    A random subset of the upper-triangular entries of ``W`` is read out,
    signed by whether the two corresponding entries of ``x`` agree, and
    averaged into a point estimate. A Gaussian likelihood centred on that
    estimate is evaluated on ``s_space`` and multiplied by ``prior``.

    Parameters
    ----------
    W : (N, N) array
        Connection matrix.
    x : (N,) numpy array
        Pattern used to sign the connections (indexed as ``x[:, None] + x``).
        The sign term ``(-1)**(x_i + x_j)`` is +1/-1 for integer-valued
        entries such as 0/1.
    M : int
        Enters the likelihood variance as ``M - 1``.
    e_s2 : float
        Second-moment factor in the likelihood variance.
    frac : float
        Fraction of the ``N*(N-1)/2`` upper-triangular connections used.
    prior : array broadcastable against ``s_space``
        Prior weights on the grid.
    s_space : array
        Grid of strength values on which the posterior is evaluated.

    Returns
    -------
    numpy.ndarray
        ``likelihood(s_space) * prior``; not normalized.

    Notes
    -----
    The connection subset is drawn with the global ``np.random`` state.
    """
    N = len(x)
    count = (N * (N - 1)) / 2

    # Builtin int() replaces the np.int alias, which was removed in NumPy 1.24.
    w_on = np.zeros(int(count))
    w_on[:int(count * frac)] = 1
    np.random.shuffle(w_on)

    # Scatter the on/off flags onto the strict upper triangle of an N x N mask.
    mask = np.zeros((N, N))
    mask[np.triu_indices(N, 1)] = w_on
    W_f = mask * W

    total = 4 * np.sum(W_f * (-1) ** (x[:, None] + x))

    mu_W = total / (count * frac)
    var_W = (M - 1) * e_s2 / (count * frac)

    # scipy's `scale` is a standard deviation, so convert the variance first
    # (the original passed var_W directly, which mis-sized the likelihood).
    prob_s = stats.norm(mu_W, np.sqrt(var_W)).pdf(s_space) * prior

    return prob_s

In [10]:
# Result containers: every array holds one value per
# (index into set_of_Cs, index into set_of_Ss, trial).

trials = 2000

result_shape = (len(set_of_Cs), len(set_of_Ss), trials)

# Strength estimates read out from the recalled pattern / from the noisy cue.
recovered_s_control = np.zeros(result_shape)
recovered_s_hlesion = np.zeros(result_shape)

recovered_s_control_lur = np.zeros(result_shape)
recovered_s_hlesion_lur = np.zeros(result_shape)

# Recall error for each of the three recall variants.
error_known = np.zeros(result_shape)
error_dual = np.zeros(result_shape)
error_mono = np.zeros(result_shape)

# Sum of the entries of the true pattern and of each recalled pattern.
proj_x_true = np.zeros(result_shape)
proj_x_known = np.zeros(result_shape)
proj_x_dual = np.zeros(result_shape)
proj_x_mono = np.zeros(result_shape)

In [11]:
# Background strength distribution: a gamma with shape k_back whose scale
# is chosen so that its mean equals mu_back.
k_back = 2.0
mu_back = 1.0
background_dist = stats.gamma(a=k_back, loc=0, scale=mu_back / k_back)

# Mean, standard deviation (mu_back**2 / k_back is that gamma's variance)
# and second moment E[s^2] = mean^2 + variance.
mu_P0 = mu_back
sigma_P0 = np.sqrt(mu_back * mu_back / k_back)
e_s2_P0 = mu_P0 ** 2 + sigma_P0 ** 2

In [12]:
# Strength-update rules handed to recall.maps in the simulation loop.
update_s_known = create.update_s_known()
update_s_frac = create.update_s_frac(mu_P0, 10000, s_factor=1.0, beta=1000, frac=1.0)

# Fraction of connections that get_posterior_s reads out.
frac = 1.0

# Very wide Gaussian (scale 10000) centred on mu_P0, normalized to sum to 1
# over s_space; used as an effectively flat prior on the strength.
flat = stats.norm(mu_P0, 10000).pdf
flat_density = flat(s_space)
flat_v = flat_density / np.sum(flat_density)
prior_s_v = flat_v

# Main simulation: for every trial, sweep c over set_of_Cs and the strength of
# pattern 0 over set_of_Ss, run three recall variants and store the results.
# NOTE(review): no random seed is set in the visible cells, so runs are not reproducible.
# NOTE(review): recovered_s_control_lur / recovered_s_hlesion_lur are never
# written in this loop and stay all-zero — confirm that is intended.

# Wall-clock reference for the per-trial timing printed below.
start_time = time.time()

for t in range(trials):
    # Progress line: trial index and minutes since the previous trial started
    # (close to 0 on the first pass, since the timer was only just started).
    print(str(t) + '; '+str((time.time() - start_time)/60))
    start_time = time.time()

    # Generate the background patterns and strengths.
    # NOTE(review): the strengths returned by memories_gaussian are overwritten
    # on the next line by N draws from background_dist — confirm that N (not M)
    # is the intended number of strengths.
    xs, ss    = generate.memories_gaussian(N, p, M, mu_P0, sigma_P0)
    ss        = background_dist.rvs(N)

    # Recall cue: pattern 0 passed through generate.noise with level r.
    # It is drawn once per trial and reused for every (c, s) combination.
    xt = generate.noise(xs[0],r)

    # NOTE: the loop variable `c` rebinds the module-level `c`.
    for i_c, c in enumerate(set_of_Cs):

        # The first argument is a function of snr that ignores it and returns the current c.
        update_x = create.update_x(lambda snr: c, e_s2 = e_s2_P0, beta=1000)

        for i_s, s in enumerate(set_of_Ss):

            # Overwrite, in place, the strength of pattern 0 with the swept value.
            ss[0] = s

            # NOTE(review): 0.5 is hard-coded here; it equals the current value of p — confirm whether p was meant.
            W = generate.weight_matrix(xs, ss, 0.5)
            proj_x_true[i_c, i_s, t] = np.sum(xs[0])

            # Initial strength estimate from the noisy cue, divided by (1-2r)^2
            # with the divisor floored at 0.15.
            s0 = update_s_frac(W, xt, M, mu_P0)/ np.max([0.15,(1-2*r)**2])

            ### Dual: recall started from s0, using update_s_frac as the strength update.
            # The recalled pattern is the element-wise median of the last 50 samples.
            x_samples, s_samples, _ = recall.maps(W, M, xt, s0, update_x, update_s_frac, samples = 100)
            x_recall                = np.median(x_samples[-50:],axis=0)

            # Strength estimate (grid point maximizing the posterior) read out with the recalled pattern...
            post_s_v = get_posterior_s(W, x_recall, M, e_s2_P0, frac, prior_s_v, s_space)
            recovered_s_control[i_c, i_s, t] = s_space[np.argmax(post_s_v)]

            # ...and read out with the noisy cue itself, bypassing recall.
            post_s_v = get_posterior_s(W, xt, M, e_s2_P0, frac, prior_s_v, s_space)
            recovered_s_hlesion[i_c, i_s, t] = s_space[np.argmax(post_s_v)]

            # Sum of (true + recalled) mod 2; for 0/1 vectors this counts mismatched entries.
            error_dual[i_c, i_s, t] = np.sum((xs[0]+x_recall)%2)
            proj_x_dual[i_c, i_s, t] = np.sum(x_recall)


            ### Mono: recall given the prior mean mu_P0 as the strength, with update_s_known.

            x_samples, s_samples, _ = recall.maps(W, M, xt, mu_P0, update_x, update_s_known, samples = 100)
            x_recall                = np.median(x_samples[-50:],axis=0)

            error_mono[i_c, i_s, t] = np.sum((xs[0]+x_recall)%2)
            proj_x_mono[i_c, i_s, t] = np.sum(x_recall)

            ### Known: recall given the true strength ss[0] (== s), with update_s_known.

            x_samples, s_samples, _ = recall.maps(W, M, xt, ss[0], update_x, update_s_known, samples = 100)
            x_recall                = np.median(x_samples[-50:],axis=0)

            error_known[i_c, i_s, t] = np.sum((xs[0]+x_recall)%2)
            proj_x_known[i_c, i_s, t] = np.sum(x_recall)


0; 1.2803077697753907e-05
1; 0.4718387285868327
2; 0.46866677204767865
3; 0.3856357177098592
4; 0.5118104775746664
5; 0.40912609895070395
6; 0.48678927024205526
7; 0.4186014413833618
8; 0.487693993250529
9; 0.4218681772549947
10; 0.3825897256533305
11; 0.41349738438924155
12; 0.4265827536582947
13; 0.4612986167271932
14; 0.3980120658874512
15; 0.3671307404836019
16; 0.5538978219032288
17; 0.36761700312296547
18; 0.4231141209602356
19; 0.37164460817972816
20; 0.3430081526438395
21; 0.4097338716189067
22; 0.403270415465037
23; 0.5850799441337585
24; 0.388702662785848
25; 0.42032207250595094
26; 0.41822070280710855
27; 0.41323577165603637
28; 0.42398364543914796
29; 0.3958686113357544
30; 0.45310867627461754
31; 0.5097835779190063
32; 0.4821464538574219
33; 0.37063784201939903
34; 0.3732147534688314
35; 0.430790913105011
36; 0.4108254988988241
37; 0.40561251242955526
38; 0.3809918244679769
39; 0.4851331869761149
40; 0.4233357032140096
41; 0.443233859539032
42; 0.4735461155573527
43; 0.484

342; 0.4413558006286621
343; 0.387098757425944
344; 0.3929562926292419
345; 0.2917089025179545
346; 0.4307648539543152
347; 0.4302629272143046
348; 0.36528298060099285
349; 0.4138432741165161
350; 0.3721376578013102
351; 0.3743242144584656
352; 0.41202810605367024
353; 0.4033811171849569
354; 0.37423205773035684
355; 0.4533246437708537
356; 0.37072639862696327
357; 0.3812356233596802
358; 0.4054020921389262
359; 0.4488947947820028
360; 0.4736668626467387
361; 0.37836446762084963
362; 0.46109488010406496
363; 0.3686273415883382
364; 0.4163077553113302
365; 0.3683813492457072
366; 0.4128304203351339
367; 0.42068952322006226
368; 0.44081188837687174
369; 0.35557915369669596
370; 0.39931642611821494
371; 0.47388832171758016
372; 0.3904533863067627
373; 0.37839967012405396
374; 0.39754416545232135
375; 0.44062031507492067
376; 0.36606224775314333
377; 0.46497464577356973
378; 0.3510290225346883
379; 0.43846480051676434
380; 0.3646798610687256
381; 0.4020023981730143
382; 0.41206745704015096

679; 0.46242125034332277
680; 0.4815153559048971
681; 0.4343506574630737
682; 0.40346188147862755
683; 0.3402028242746989
684; 0.35955745776494347
685; 0.4611325144767761
686; 0.46888431708017986
687; 0.4477712392807007
688; 0.4133230368296305
689; 0.39069151878356934
690; 0.3380595088005066
691; 0.37060683965682983
692; 0.4399110674858093
693; 0.3932195782661438
694; 0.37296216090520223
695; 0.402900501092275
696; 0.40785882870356244
697; 0.42237171332041423
698; 0.4195855140686035
699; 0.3795738697052002
700; 0.3735086679458618
701; 0.4584349234898885
702; 0.4391951640446981
703; 0.4119004845619202
704; 0.3946425437927246
705; 0.43499403397242226
706; 0.4208776076634725
707; 0.3953389167785645
708; 0.4991860826810201
709; 0.39496378898620604
710; 0.34647157589594524
711; 0.45177980661392214
712; 0.38322673241297406
713; 0.4153739889462789
714; 0.3805923064549764
715; 0.4229309837023417
716; 0.4045832792917887
717; 0.4221341331799825
718; 0.3487984458605448
719; 0.4932897369066874
720

1015; 0.44711257219314576
1016; 0.5292712767918905
1017; 0.4195028066635132
1018; 0.510687013467153
1019; 0.42771485646565754
1020; 0.4418508291244507
1021; 0.416861367225647
1022; 0.45389086405436196
1023; 0.3789628823598226
1024; 0.3878537694613139
1025; 0.41216317812601727
1026; 0.3636562665303548
1027; 0.41768134037653604
1028; 0.4035526514053345
1029; 0.45295116504033406
1030; 0.3875695586204529
1031; 0.4817926367123922
1032; 0.3762409448623657
1033; 0.38675599892934165
1034; 0.39770483573277793
1035; 0.4383177042007446
1036; 0.3691099246342977
1037; 0.4313679218292236
1038; 0.3655800422032674
1039; 0.4659033139546712
1040; 0.5016448577245076
1041; 0.4222176869710286
1042; 0.43844895362854003
1043; 0.3809767683347066
1044; 0.42945401271184286
1045; 0.43364978631337486
1046; 0.4424243172009786
1047; 0.41306609312693277
1048; 0.3752135276794434
1049; 0.40898909568786623
1050; 0.39965091546376547
1051; 0.47901641925175986
1052; 0.4371373017628988
1053; 0.4508318106333415
1054; 0.4038

1339; 0.46948297023773194
1340; 0.37276012102762857
1341; 0.4089037537574768
1342; 0.43299057086308795
1343; 0.4692487875620524
1344; 0.45082377990086875
1345; 0.4849273800849915
1346; 0.3796820799509684
1347; 0.41409846941630046
1348; 0.4316651463508606
1349; 0.45545262495676675
1350; 0.4026072065035502
1351; 0.4097144802411397
1352; 0.45280841588973997
1353; 0.4479187568028768
1354; 0.4480859041213989
1355; 0.3768067677815755
1356; 0.42426061630249023
1357; 0.4376789649327596
1358; 0.37823126713434857
1359; 0.4566054145495097
1360; 0.40528968572616575
1361; 0.4410380244255066
1362; 0.35791253646214805
1363; 0.37376906077067057
1364; 0.40137516260147094
1365; 0.379486083984375
1366; 0.43543986876805624
1367; 0.3867769956588745
1368; 0.35941417614618937
1369; 0.41670748790105183
1370; 0.40275618235270183
1371; 0.38646737734476727
1372; 0.40406445662180585
1373; 0.42856539885203043
1374; 0.4652891000111898
1375; 0.3663382371266683
1376; 0.4424425880114237
1377; 0.4663862983385722
1378; 

1663; 0.3970930814743042
1664; 0.39717466036478677
1665; 0.38935578664143883
1666; 0.41311238606770834
1667; 0.374888531366984
1668; 0.37280990680058795
1669; 0.39315665562947594
1670; 0.43489052454630533
1671; 0.3776468555132548
1672; 0.4604193687438965
1673; 0.3760794043540955
1674; 0.4347539901733398
1675; 0.34989562431971233
1676; 0.4071472962697347
1677; 0.5608251810073852
1678; 0.39449910720189413
1679; 0.4587640523910522
1680; 0.3809356093406677
1681; 0.40062660773595177
1682; 0.4808959007263184
1683; 0.34432993332544964
1684; 0.4232513348261515
1685; 0.4297764658927917
1686; 0.401669446627299
1687; 0.44206862449645995
1688; 0.4688251574834188
1689; 0.43197659651438397
1690; 0.41987046003341677
1691; 0.3669391632080078
1692; 0.4372835874557495
1693; 0.42257538239161174
1694; 0.4161632895469666
1695; 0.4374661167462667
1696; 0.39861178000768027
1697; 0.3979090372721354
1698; 0.36060527165730794
1699; 0.442104967435201
1700; 0.4019728819529215
1701; 0.41529415448506674
1702; 0.426

1987; 0.410726801554362
1988; 0.46903758446375526
1989; 0.38564218680063883
1990; 0.3908641974131266
1991; 0.4723152756690979
1992; 0.3734546422958374
1993; 0.43914703528086346
1994; 0.5062259117762248
1995; 0.42635363737742105
1996; 0.3963603854179382
1997; 0.38932253122329713
1998; 0.41773657004038495
1999; 0.37211719353993733


In [13]:
# Persist every result array as a .npy file under ./saves_recall.
# The directory is created first: np.save does not create missing parent
# directories, and failing here would discard the whole simulation run.
# NOTE(review): the "_030" suffix presumably encodes r = 0.30 — confirm and
# keep it in sync if r is changed.
os.makedirs('./saves_recall', exist_ok=True)

np.save('./saves_recall/recovered_s_control_030.npy', recovered_s_control)
np.save('./saves_recall/recovered_s_hlesion_030.npy', recovered_s_hlesion)
np.save('./saves_recall/recovered_s_control_lur_030.npy', recovered_s_control_lur)
np.save('./saves_recall/recovered_s_hlesion_lur_030.npy', recovered_s_hlesion_lur)

np.save('./saves_recall/error_known_030.npy', error_known)
np.save('./saves_recall/error_dual_030.npy', error_dual)
np.save('./saves_recall/error_mono_030.npy', error_mono)

np.save('./saves_recall/proj_x_true_030.npy', proj_x_true)
np.save('./saves_recall/proj_x_known_030.npy', proj_x_known)
np.save('./saves_recall/proj_x_dual_030.npy', proj_x_dual)
np.save('./saves_recall/proj_x_mono_030.npy', proj_x_mono)