In [1]:
import os,sys,inspect
# Make the project modules (create, generate, recall, ...) importable by putting the
# parent and grand-parent of this notebook's directory at the front of sys.path.
#
# When this code runs as a plain script `__file__` points at the script. Inside a
# Jupyter kernel `__file__` is undefined, and the frame's filename is a synthetic
# per-cell name (on recent ipykernel a path under the system temp directory), so it
# must not be used to locate the project; the kernel's working directory is used instead.
try:
    currentdir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    currentdir = os.getcwd()
parentdir = os.path.dirname(currentdir)
sparentdir = os.path.dirname(parentdir)
# Inserted in this order so that sparentdir ends up ahead of parentdir on sys.path.
sys.path.insert(0, parentdir)
sys.path.insert(0, sparentdir)

import datetime
import numpy as np
import scipy.stats as stats
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from matplotlib import gridspec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import create
import generate
import recall
import simulations
import simsave
import load_curves

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
import matplotlib as mpl
import copy
import pickle
import numpy as np
import pandas as pd
import time
import seaborn as sns


# --- Simulation configuration -------------------------------------------------
# N, p and M are forwarded to generate.memories_gaussian; M is also handed to the
# recall and strength-update routines. Presumably: pattern size, coding level and
# number of stored memories -- confirm against the `generate` module.
N, M = 100, 30
p = 0.5
# r is the parameter given to generate.noise when the recall cue is built.
r = 0.2
# Default value only: the simulation loop rebinds `c` while sweeping set_of_Cs.
c = 0.65

# Grids swept by the simulation (target strength, and the value returned by the
# lambda passed to create.update_x) ...
set_of_Ss = np.linspace(0.0, 5.0, num=51)
set_of_Cs = np.linspace(0.4, 0.7, num=13)
# ... and the grid on which the posterior over the strength is evaluated.
s_space = np.linspace(-10.0, 15.0, num=10001)

In [2]:
#Get discretized posterior over S given a pattern and a connection matrix. 
def get_posterior_s(W, x, M, e_s2, frac, prior, s_space):
    """Return the unnormalised, discretised posterior over the strength s.

    A random subset of the upper-triangular connections of ``W`` is read out,
    each weight is multiplied by ``(-1) ** (x_i + x_j)``, and the scaled mean of
    those products is used as the centre of a Gaussian likelihood over
    ``s_space``. The likelihood is multiplied point-wise by ``prior``.

    Parameters
    ----------
    W : (N, N) array
        Connection matrix; only entries strictly above the diagonal are used.
    x : length-N array-like
        Pattern used to sign the weights (presumably 0/1 valued -- verify
        against the caller).
    M : number
        Enters the likelihood variance as ``(M - 1) * e_s2``.
    e_s2 : float
        Enters the likelihood variance as ``(M - 1) * e_s2``.
    frac : float
        Fraction of the ``N * (N - 1) / 2`` upper-triangular connections that
        are sampled.
    prior : array broadcastable against ``s_space``
        Prior weights evaluated on ``s_space``.
    s_space : array
        Grid of candidate strengths.

    Returns
    -------
    ndarray
        ``likelihood(s_space) * prior``; not normalised.

    Raises
    ------
    ValueError
        If ``count * frac`` is not positive (fewer than two units or a
        non-positive ``frac``), which would otherwise divide by zero.

    Notes
    -----
    The connection subset is drawn with ``np.random.shuffle`` and therefore
    consumes the global NumPy random state.
    """
    x = np.asarray(x)

    N = len(x)
    count = (N * (N - 1)) // 2      # number of connections above the diagonal
    n_eff = count * frac            # (possibly fractional) number of sampled connections
    if n_eff <= 0:
        raise ValueError("get_posterior_s needs at least one connection "
                         "(len(x) >= 2 and frac > 0)")

    # Built-in int replaces the np.int alias, which was removed from NumPy.
    w_on = np.zeros(count)
    w_on[:int(n_eff)] = 1
    np.random.shuffle(w_on)

    # Scatter the 0/1 selection onto the strict upper triangle of an N x N mask.
    mask = np.zeros((N, N))
    mask[np.triu_indices(N, 1)] = w_on
    W_f = mask * W

    # Sign every selected weight by the parity of x_i + x_j, then sum.
    total = 4 * np.sum(W_f * (-1) ** (x[:, None] + x))

    mu_W = total / n_eff
    var_W = (M - 1) * e_s2 / n_eff

    # scipy's `scale` is a standard deviation, so the variance must be square-rooted.
    prob_s = stats.norm(mu_W, np.sqrt(var_W)).pdf(s_space) * prior

    return prob_s

In [3]:
# Containers for the simulation results.
# Every array is indexed as [index into set_of_Cs, index into set_of_Ss, trial].

trials = 2000

result_shape = (len(set_of_Cs), len(set_of_Ss), trials)

# Strength read out from the posterior (argmax over s_space).
recovered_s_control, recovered_s_hlesion = (np.zeros(result_shape) for _ in range(2))

recovered_s_control_lur, recovered_s_hlesion_lur = (np.zeros(result_shape) for _ in range(2))

# Recall error for each of the three recall variants.
error_known, error_dual, error_mono = (np.zeros(result_shape) for _ in range(3))

# Sum of the entries of the true pattern and of each recalled pattern.
proj_x_true, proj_x_known, proj_x_dual, proj_x_mono = (
    np.zeros(result_shape) for _ in range(4)
)

In [4]:
# Background strength distribution: a gamma law with shape k_back and scale
# mu_back / k_back, i.e. mean mu_back and variance mu_back**2 / k_back.
k_back = 2.0
mu_back = 1.0
background_dist = stats.gamma(a=k_back, loc=0, scale=mu_back / k_back)

# First two moments of that distribution, reused by the update rules and the
# posterior: mean, standard deviation and second moment E[s^2] = mean^2 + variance.
sigma_P0 = np.sqrt(mu_back * mu_back / k_back)
mu_P0 = mu_back
e_s2_P0 = mu_P0 ** 2 + sigma_P0 ** 2

In [5]:
# Strength-update rules handed to recall.maps: one built without arguments and one
# configured with the background mean, a width of 10000 and the full connection fraction.
update_s_known = create.update_s_known()
update_s_frac = create.update_s_frac(mu_P0, 10000, s_factor=1.0, beta=1000, frac=1.0)
# Fraction of connections that get_posterior_s reads out.
frac = 1.0

# Prior over s on the s_space grid: a Gaussian centred on mu_P0 with scale 10000
# (so nearly constant across the grid), normalised to sum to one.
flat = stats.norm(mu_P0, 10000).pdf
flat_v = flat(s_space)
flat_v = flat_v / np.sum(flat_v)
prior_s_v = flat_v

# Reference point for the per-trial timing printed at the top of every iteration.
start_time = time.time()

# Main simulation. For every trial a fresh set of patterns and a noisy cue are drawn;
# the cue is then recalled for every (c, s) pair on the set_of_Cs x set_of_Ss grid with
# three recall variants, and the results are written into the pre-allocated arrays at
# index [i_c, i_s, t].
# NOTE(review): recovered_s_control_lur and recovered_s_hlesion_lur are not assigned
# anywhere in this loop.
for t in range(trials):
    # Progress line: trial index and minutes elapsed since the previous iteration started.
    print(str(t) + '; '+str((time.time() - start_time)/60))
    start_time = time.time()

    # Generate the background patterns and strengths.
    xs, ss    = generate.memories_gaussian(N, p, M, mu_P0, sigma_P0)
    # The strengths returned above are discarded and redrawn from the gamma background.
    # NOTE(review): N values are drawn here although M is also passed to the generator
    # above -- confirm which length `ss` is expected to have.
    ss        = background_dist.rvs(N)

    # Recall cue derived from the first pattern with parameter r
    # (presumably a corrupted copy of xs[0] -- verify in generate.noise).
    xt = generate.noise(xs[0],r)

    # NOTE: this loop rebinds the module-level name `c`.
    for i_c, c in enumerate(set_of_Cs):

        # Pattern-update rule; the lambda ignores its `snr` argument and returns the
        # current c. It reads `c` when called, which is safe here because update_x is
        # rebuilt on every iteration of this loop before it is used.
        update_x = create.update_x(lambda snr: c, e_s2 = e_s2_P0, beta=1000)
        
        for i_s, s in enumerate(set_of_Ss):

            # Overwrite the strength of the first pattern with the swept value (in place).
            ss[0] = s

            # Connection matrix built from all patterns and strengths
            # (meaning of the 0.5 argument is defined in generate.weight_matrix -- TODO confirm).
            W = generate.weight_matrix(xs, ss, 0.5)
            proj_x_true[i_c, i_s, t] = np.sum(xs[0])

            ### Dual: strength is estimated during recall with update_s_frac.

            # Initial strength estimate from the cue, divided by max(0.15, (1 - 2r)^2).
            s0 = update_s_frac(W, xt, M, mu_P0)/ np.max([0.15,(1-2*r)**2])

            x_samples, s_samples, _ = recall.maps(W, M, xt, s0, update_x, update_s_frac, samples = 100)
            # Element-wise median over the last 50 entries of x_samples.
            # NOTE(review): a median over an even number of 0/1 samples can be 0.5 -- confirm
            # that downstream code (parity sum, get_posterior_s) tolerates non-integer entries.
            x_recall                = np.median(x_samples[-50:],axis=0)
            
            # Grid value of s_space that maximises the posterior computed from the recalled pattern.
            post_s_v = get_posterior_s(W, x_recall, M, e_s2_P0, frac, prior_s_v, s_space)
            recovered_s_control[i_c, i_s, t] = s_space[np.argmax(post_s_v)]

            # Same read-out, but computed from the cue xt instead of the recalled pattern.
            post_s_v = get_posterior_s(W, xt, M, e_s2_P0, frac, prior_s_v, s_space)
            recovered_s_hlesion[i_c, i_s, t] = s_space[np.argmax(post_s_v)]

            # Sum of (true + recalled) mod 2; for 0/1 integer patterns this is the number
            # of entries that differ.
            error_dual[i_c, i_s, t] = np.sum((xs[0]+x_recall)%2)
            proj_x_dual[i_c, i_s, t] = np.sum(x_recall)


            ### Mono: recall starts from the background mean mu_P0 and uses update_s_known.

            x_samples, s_samples, _ = recall.maps(W, M, xt, mu_P0, update_x, update_s_known, samples = 100)
            x_recall                = np.median(x_samples[-50:],axis=0)

            error_mono[i_c, i_s, t] = np.sum((xs[0]+x_recall)%2)
            proj_x_mono[i_c, i_s, t] = np.sum(x_recall)

            ### Known: recall starts from the true strength ss[0] and uses update_s_known.

            x_samples, s_samples, _ = recall.maps(W, M, xt, ss[0], update_x, update_s_known, samples = 100)
            x_recall                = np.median(x_samples[-50:],axis=0)

            error_known[i_c, i_s, t] = np.sum((xs[0]+x_recall)%2)
            proj_x_known[i_c, i_s, t] = np.sum(x_recall)


0; 1.0208288828531902e-05
1; 0.42082829078038536
2; 0.46512332757314045
3; 0.36711854139963784
4; 0.40879948139190675
5; 0.37010738054911296
6; 0.42506951491038003
7; 0.37033714056015016
8; 0.4546724279721578
9; 0.382108211517334
10; 0.40542630354563397
11; 0.45198352734247843
12; 0.37492748896280925
13; 0.34879671335220336
14; 0.3513303240140279
15; 0.3659652233123779
16; 0.37788283824920654
17; 0.35502455234527586
18; 0.3961110552151998
19; 0.34583155314127606
20; 0.32044535080591835
21; 0.4433298865954081
22; 0.3407904028892517
23; 0.45308028856913246
24; 0.3574668407440186
25; 0.37935314178466795
26; 0.4402206023534139
27; 0.43931852976481117
28; 0.4280788739522298
29; 0.34882022539774576
30; 0.34480199416478474
31; 0.33431840340296426
32; 0.42635337511698407
33; 0.36597362359364827
34; 0.31893765131632484
35; 0.4742465575536092
36; 0.4880930225054423
37; 0.3572260578473409
38; 0.4525296131769816
39; 0.3642204443613688
40; 0.47445542017618814
41; 0.3238323171933492
42; 0.4264123121

341; 0.4466460903485616
342; 0.3535556356112162
343; 0.40310736497243244
344; 0.38012447357177737
345; 0.3971664587656657
346; 0.3836017489433289
347; 0.3799813508987427
348; 0.3619847615559896
349; 0.41534281969070436
350; 0.4393507599830627
351; 0.3359441637992859
352; 0.35282341241836546
353; 0.3941325743993123
354; 0.4112521727879842
355; 0.46434847116470335
356; 0.4025121609369914
357; 0.3898046533266703
358; 0.36338074604670206
359; 0.37536060412724814
360; 0.43555645942687987
361; 0.39479856888453163
362; 0.4440123915672302
363; 0.3827550053596497
364; 0.3469746232032776
365; 0.32107433478037517
366; 0.3845272501309713
367; 0.4029255708058675
368; 0.3785644690195719
369; 0.40601238012313845
370; 0.36257540384928383
371; 0.40822689135869344
372; 0.37363774379094444
373; 0.4406944712003072
374; 0.3746981382369995
375; 0.3655571103096008
376; 0.3692396124204
377; 0.5022567073504131
378; 0.40054181814193723
379; 0.3605930288632711
380; 0.31338181098302204
381; 0.37622180779774983
38

678; 0.34215993881225587
679; 0.3842954913775126
680; 0.41881189743677777
681; 0.35196364323298135
682; 0.42333229780197146
683; 0.37480581998825074
684; 0.37751572926839194
685; 0.3900676647822062
686; 0.3794615586598714
687; 0.380939257144928
688; 0.35291866064071653
689; 0.3483309785525004
690; 0.3600681940714518
691; 0.4117237448692322
692; 0.33936957518259686
693; 0.3580283999443054
694; 0.3893278400103251
695; 0.3882120688756307
696; 0.3336210330327352
697; 0.3645633300145467
698; 0.36925933758417767
699; 0.3280032634735107
700; 0.3387199362119039
701; 0.3330108284950256
702; 0.39268211523691815
703; 0.3793555418650309
704; 0.3749221603075663
705; 0.342602002620697
706; 0.3978835622469584
707; 0.39359392325083414
708; 0.39006452163060507
709; 0.3586022098859151
710; 0.3415138959884644
711; 0.37255679766337074
712; 0.46221171220143636
713; 0.35791671673456826
714; 0.3072458306948344
715; 0.4091405232747396
716; 0.35764132340749105
717; 0.4360741376876831
718; 0.3465961734453837
71

1015; 0.4559145967165629
1016; 0.3534919222195943
1017; 0.3498240828514099
1018; 0.36646903355916344
1019; 0.32287495136260985
1020; 0.34125728607177735
1021; 0.4417798479398092
1022; 0.3753102898597717
1023; 0.4640244722366333
1024; 0.38124927679697673
1025; 0.42772501707077026
1026; 0.3878180106480916
1027; 0.380270524819692
1028; 0.4486743489901225
1029; 0.38679123719533287
1030; 0.3510158658027649
1031; 0.38622322082519533
1032; 0.3910283923149109
1033; 0.4287941416104635
1034; 0.4279811859130859
1035; 0.38581430514653525
1036; 0.32467629909515383
1037; 0.40559819936752317
1038; 0.3945926586786906
1039; 0.43635935386021935
1040; 0.40582611163457233
1041; 0.39193238417307535
1042; 0.39760866165161135
1043; 0.40716602007548014
1044; 0.4243645787239075
1045; 0.36129236618677774
1046; 0.37871926625569663
1047; 0.43448715607325233
1048; 0.4392321348190308
1049; 0.4358256538709005
1050; 0.534736963113149
1051; 0.37098173300425213
1052; 0.4104319175084432
1053; 0.41492910385131837
1054; 0

1338; 0.4579637924830119
1339; 0.39043573538462323
1340; 0.37992389996846515
1341; 0.3981114188830058
1342; 0.36119996309280394
1343; 0.45571311712265017
1344; 0.40187137921651206
1345; 0.34929489294687904
1346; 0.4337146004041036
1347; 0.4047123074531555
1348; 0.35716182390848794
1349; 0.3593523542086283
1350; 0.42348877986272176
1351; 0.4395078818003337
1352; 0.39953778982162474
1353; 0.40612982511520385
1354; 0.3714919765790304
1355; 0.41441004276275634
1356; 0.38601980606714886
1357; 0.46453861395517987
1358; 0.4625599145889282
1359; 0.419103995958964
1360; 0.4091175357500712
1361; 0.3588599960009257
1362; 0.4362871011098226
1363; 0.37438029050827026
1364; 0.4120012760162354
1365; 0.38250526984532673
1366; 0.39589821100234984
1367; 0.3880910396575928
1368; 0.37989208698272703
1369; 0.41879911025365196
1370; 0.5340120355288188
1371; 0.3686189254124959
1372; 0.40081594785054525
1373; 0.4535877466201782
1374; 0.3645467758178711
1375; 0.4021950920422872
1376; 0.41533787647883097
1377; 

1661; 0.40729533036549886
1662; 0.4297516147295634
1663; 0.4320240418116252
1664; 0.44476160605748494
1665; 0.3569261153539022
1666; 0.3942153096199036
1667; 0.33817986249923704
1668; 0.36697185834248863
1669; 0.36018269856770835
1670; 0.3389775037765503
1671; 0.35702603658040366
1672; 0.4851383646329244
1673; 0.3890477975209554
1674; 0.3765144109725952
1675; 0.48209919532140094
1676; 0.3779218792915344
1677; 0.404222297668457
1678; 0.3662186781565348
1679; 0.33411020835240685
1680; 0.3600248177846273
1681; 0.33815986712773644
1682; 0.40326988299687705
1683; 0.31998914082845054
1684; 0.3430778503417969
1685; 0.38997071981430054
1686; 0.33916399081548054
1687; 0.36430272658665974
1688; 0.3460286855697632
1689; 0.40527030229568484
1690; 0.3550053000450134
1691; 0.40466910203297934
1692; 0.3299924294153849
1693; 0.3937896291414897
1694; 0.40097312529881796
1695; 0.3394983490308126
1696; 0.32320475578308105
1697; 0.3337643265724182
1698; 0.3400552352269491
1699; 0.3922266483306885
1700; 0.

1984; 0.31320778131484983
1985; 0.3451602776845296
1986; 0.32509618600209556
1987; 0.37347373565038045
1988; 0.3338062286376953
1989; 0.46038933197657267
1990; 0.3735928257306417
1991; 0.41064438422520955
1992; 0.34091459910074867
1993; 0.4184396743774414
1994; 0.3572628895441691
1995; 0.3330875595410665
1996; 0.4080255707105001
1997; 0.34141101042429606
1998; 0.38099408547083535
1999; 0.3676615436871847


In [6]:
# Persist every result array as '<save_dir>/<name>_020.npy'.
# (The '_020' suffix presumably encodes r = 0.20 -- confirm before reusing with another r.)
save_dir = './saves_recall'

# np.save does not create missing directories; create the target up front so that a
# finished simulation run is not lost to a FileNotFoundError at the very last step.
os.makedirs(save_dir, exist_ok=True)

results_to_save = {
    'recovered_s_control': recovered_s_control,
    'recovered_s_hlesion': recovered_s_hlesion,
    'recovered_s_control_lur': recovered_s_control_lur,
    'recovered_s_hlesion_lur': recovered_s_hlesion_lur,
    'error_known': error_known,
    'error_dual': error_dual,
    'error_mono': error_mono,
    'proj_x_true': proj_x_true,
    'proj_x_known': proj_x_known,
    'proj_x_dual': proj_x_dual,
    'proj_x_mono': proj_x_mono,
}

for result_name, result_values in results_to_save.items():
    np.save('{}/{}_020.npy'.format(save_dir, result_name), result_values)