In [1]:
import os,sys,inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sparentdir = os.path.dirname(parentdir)
sys.path.insert(0,parentdir)
sys.path.insert(0,sparentdir)

import datetime
import numpy as np
import scipy.stats as stats
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from matplotlib import gridspec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import create
import generate
import recall
import simulations
import simsave
import load_curves

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
import matplotlib as mpl
import copy
import pickle
import numpy as np
import pandas as pd
import time
import seaborn as sns


N = 100
p = 0.5
c = 0.65
r = 0.3
M = 30

set_of_Ss = np.linspace(0,5.0,51)
set_of_Cs = np.linspace(0.4,0.7,13)
s_space = np.linspace(-10.0,15.0,10001)

In [2]:
#Get discretized posterior over S given a pattern and a connection matrix. 
def get_posterior_s(W, x, M, e_s2, frac, prior, s_space):
    
    N = len(x)
    count = (N * (N - 1)) / 2
    
    w_on = np.zeros(np.int(count))
    w_on[:np.int(count*frac)] = 1
    np.random.shuffle(w_on)

    mask = np.zeros((N, N))
    mask[np.triu_indices(N, 1)] = w_on
    W_f = mask * W

    total = 4*np.sum(W_f*(-1)**(x[:, None]+x))
    

    mu_W = total / (count*frac)
    var_W = (M - 1) * e_s2 / (count*frac)
    
    prob_s = stats.norm(mu_W, var_W).pdf(s_space) * prior
    
    return prob_s

In [3]:
#Store Results from Simulation

trials = 2000

recovered_s_control     = np.zeros([len(set_of_Cs), len(set_of_Ss), trials])
recovered_s_hlesion     = np.zeros([len(set_of_Cs), len(set_of_Ss), trials])

recovered_s_control_lur = np.zeros([len(set_of_Cs), len(set_of_Ss), trials])
recovered_s_hlesion_lur = np.zeros([len(set_of_Cs), len(set_of_Ss), trials])

error_known             = np.zeros([len(set_of_Cs), len(set_of_Ss), trials])
error_dual              = np.zeros([len(set_of_Cs), len(set_of_Ss), trials])
error_mono              = np.zeros([len(set_of_Cs), len(set_of_Ss), trials])

proj_x_true             = np.zeros([len(set_of_Cs), len(set_of_Ss), trials])
proj_x_known            = np.zeros([len(set_of_Cs), len(set_of_Ss), trials])
proj_x_dual             = np.zeros([len(set_of_Cs), len(set_of_Ss), trials])
proj_x_mono             = np.zeros([len(set_of_Cs), len(set_of_Ss), trials])

In [4]:
# Define the Background distribution
k_back = 2.0
mu_back = 1.0
background_dist = stats.gamma(k_back, 0, mu_back/k_back)

mu_P0 = mu_back
sigma_P0 = np.sqrt(mu_back*mu_back/k_back)
e_s2_P0 = mu_P0**2 + sigma_P0**2

In [5]:
update_s_known = create.update_s_known()
update_s_frac = create.update_s_frac(mu_P0, 10000, s_factor=1.0, beta=1000, frac=1.0)
frac = 1.0

flat = stats.norm(mu_P0, 10000).pdf
flat_v = flat(s_space) / np.sum(flat(s_space))
prior_s_v = flat_v

start_time = time.time()

for t in range(trials):
    print(str(t) + '; '+str((time.time() - start_time)/60))
    start_time = time.time()

    #Generate the background patterns and strengths
    xs, ss    = generate.memories_gaussian(N, p, M, mu_P0, sigma_P0)
    ss        = background_dist.rvs(N)

    xt = generate.noise(xs[0],r)

    for i_c, c in enumerate(set_of_Cs):

        update_x = create.update_x(lambda snr: c, e_s2 = e_s2_P0, beta=1000)
        
        for i_s, s in enumerate(set_of_Ss):

            ss[0] = s

            W = generate.pseudo_matrix(xs, ss, 0.5)
            proj_x_true[i_c, i_s, t] = np.sum(xs[0])

            s0 = update_s_frac(W, xt, M, mu_P0)/ np.max([0.15,(1-2*r)**2])

            x_samples, s_samples, _ = recall.maps(W, M, xt, s0, update_x, update_s_frac, samples = 100)
            x_recall                = np.median(x_samples[-50:],axis=0)
            
            post_s_v = get_posterior_s(W, x_recall, M, e_s2_P0, frac, prior_s_v, s_space)
            recovered_s_control[i_c, i_s, t] = s_space[np.argmax(post_s_v)]

            post_s_v = get_posterior_s(W, xt, M, e_s2_P0, frac, prior_s_v, s_space)
            recovered_s_hlesion[i_c, i_s, t] = s_space[np.argmax(post_s_v)]

            error_dual[i_c, i_s, t] = np.sum((xs[0]+x_recall)%2)
            proj_x_dual[i_c, i_s, t] = np.sum(x_recall)


            ### Mono

            x_samples, s_samples, _ = recall.maps(W, M, xt, mu_P0, update_x, update_s_known, samples = 100)
            x_recall                = np.median(x_samples[-50:],axis=0)

            error_mono[i_c, i_s, t] = np.sum((xs[0]+x_recall)%2)
            proj_x_mono[i_c, i_s, t] = np.sum(x_recall)

            ### Known

            x_samples, s_samples, _ = recall.maps(W, M, xt, ss[0], update_x, update_s_known, samples = 100)
            x_recall                = np.median(x_samples[-50:],axis=0)

            error_known[i_c, i_s, t] = np.sum((xs[0]+x_recall)%2)
            proj_x_known[i_c, i_s, t] = np.sum(x_recall)


0; 1.9367535909016927e-05
1; 0.47741742531458536
2; 0.4941345532735189
3; 0.491377580165863
4; 0.4949814518292745
5; 0.539271104335785
6; 0.5124054511388143
7; 0.4665806372960409
8; 0.4980383038520813
9; 0.48194491465886435
10; 0.4513506571451823
11; 0.4707719842592875
12; 0.492418909072876
13; 0.46156089305877684
14; 0.5731915950775146
15; 0.525388526916504
16; 0.4920309543609619
17; 0.5090943773587545
18; 0.45156680345535277
19; 0.49099188645680747
20; 0.4590074022610982
21; 0.5034505804379781
22; 0.5075423717498779
23; 0.5031028946240743
24; 0.5208991885185241
25; 0.5561300834019979
26; 0.45925234953562416
27; 0.5104005614916484
28; 0.5105185945828755
29; 0.5033763964970907
30; 0.4976990421613057
31; 0.4817663033803304
32; 0.5284722844759623
33; 0.4931317726771037
34; 0.47011335293451945
35; 0.5227096796035766
36; 0.5034657478332519
37; 0.47594717343648274
38; 0.47610428333282473
39; 0.46886346737543744
40; 0.4629563808441162
41; 0.45190436442693077
42; 0.4835560997327169
43; 0.5207

343; 0.6539507190386454
344; 0.6168273448944092
345; 0.5924024184544882
346; 0.6129430492719015
347; 0.4970755338668823
348; 0.5429632186889648
349; 0.5043697118759155
350; 0.4877288818359375
351; 0.5698277354240417
352; 0.5735645532608032
353; 0.6328189889589946
354; 0.5961107929547628
355; 0.5769414742787679
356; 0.5867369016011555
357; 0.5489816506703694
358; 0.513168195883433
359; 0.5554106553395589
360; 0.5428686300913493
361; 0.5653486132621766
362; 0.5996812105178833
363; 0.6602666974067688
364; 0.5657639424006145
365; 0.6676496227582296
366; 0.6189723610877991
367; 0.521604839960734
368; 0.5357272267341614
369; 0.5744308710098267
370; 0.4843512733777364
371; 0.509546426932017
372; 0.5680715123812358
373; 0.5739622791608174
374; 0.5545584956804911
375; 0.5457168181737264
376; 0.6244531591733297
377; 0.5133617957433064
378; 0.48668606678644816
379; 0.5164118448893229
380; 0.582772719860077
381; 0.5272354284922282
382; 0.5397653500239055
383; 0.5541176557540893
384; 0.546828039487

684; 0.4941906015078227
685; 0.5041525443394979
686; 0.5062025586764017
687; 0.43681176503499347
688; 0.5163875699043274
689; 0.5294958353042603
690; 0.5441622972488404
691; 0.5396424810091655
692; 0.520952562491099
693; 0.588152289390564
694; 0.5316493431727092
695; 0.5384133736292521
696; 0.5982436696688335
697; 0.5842249910036723
698; 0.5742371996243795
699; 0.46671457290649415
700; 0.5212941447893779
701; 0.5086944739023844
702; 0.523108692963918
703; 0.5190306027730306
704; 0.4833840767542521
705; 0.4807991663614909
706; 0.5153059999148051
707; 0.4969613989194234
708; 0.5054991920789083
709; 0.49805616537729896
710; 0.504584006468455
711; 0.4704504529635111
712; 0.49230770270029706
713; 0.5053367932637532
714; 0.5144594589869181
715; 0.4810341715812683
716; 0.5077027956644694
717; 0.47573699553807575
718; 0.4607575337092082
719; 0.5584428389867147
720; 0.5199103792508443
721; 0.6195679704348246
722; 0.5239118615786235
723; 0.48244370619455973
724; 0.5083980361620585
725; 0.5256148

1023; 0.4597054998079936
1024; 0.48159643411636355
1025; 0.5183304150899252
1026; 0.47148427963256834
1027; 0.5081252853075663
1028; 0.4914876659711202
1029; 0.47311511437098186
1030; 0.45412089824676516
1031; 0.5200117985407512
1032; 0.5005654573440552
1033; 0.4928829352060954
1034; 0.46853976249694823
1035; 0.5018139640490215
1036; 0.5318201939264934
1037; 0.5207550446192424
1038; 0.5008858005205791
1039; 0.5038075566291809
1040; 0.5270953933397929
1041; 0.4834585388501485
1042; 0.47933160066604613
1043; 0.47677567799886067
1044; 0.5181915442148844
1045; 0.484541118144989
1046; 0.5134063204129536
1047; 0.5096673846244812
1048; 0.5426082730293273
1049; 0.5068736672401428
1050; 0.49893073638280233
1051; 0.5000804543495179
1052; 0.5056069254875183
1053; 0.4911647915840149
1054; 0.4993683894475301
1055; 0.4884859800338745
1056; 0.49244701862335205
1057; 0.46772451798121134
1058; 0.4801157355308533
1059; 0.5142414371172587
1060; 0.47895859479904174
1061; 0.5041394074757893
1062; 0.5391855

1349; 0.5164804299672444
1350; 0.4952393094698588
1351; 0.5163169264793396
1352; 0.43148792584737145
1353; 0.5162864128748575
1354; 0.4433845321337382
1355; 0.44211804469426474
1356; 0.4558541576067607
1357; 0.4620124061902364
1358; 0.45783912340799965
1359; 0.4859451492627462
1360; 0.48812702894210813
1361; 0.5143723209698995
1362; 0.4533719778060913
1363; 0.4484055678049723
1364; 0.4396982232729594
1365; 0.46796782811482746
1366; 0.4926478385925293
1367; 0.48637359142303466
1368; 0.44814687172571815
1369; 0.4960824251174927
1370; 0.5043420354525249
1371; 0.47002150615056354
1372; 0.4418397665023804
1373; 0.434943687915802
1374; 0.5170129179954529
1375; 0.47610653638839723
1376; 0.4529983083407084
1377; 0.4700552821159363
1378; 0.4694670597712199
1379; 0.44308145840962726
1380; 0.46932993332544964
1381; 0.47895154158274333
1382; 0.47861217260360717
1383; 0.4697612404823303
1384; 0.5066753546396892
1385; 0.4678925355275472
1386; 0.456574821472168
1387; 0.47750598589579263
1388; 0.49663

1673; 0.4946162978808085
1674; 0.4401360114415487
1675; 0.4749139229456584
1676; 0.46991730531056725
1677; 0.479740051428477
1678; 0.4682377457618713
1679; 0.47552931706110635
1680; 0.46703736782073973
1681; 0.5112707177797954
1682; 0.438254710038503
1683; 0.4863054792086283
1684; 0.4839175224304199
1685; 0.4558199167251587
1686; 0.4782140851020813
1687; 0.47080647150675453
1688; 0.491413692633311
1689; 0.4893264532089233
1690; 0.42924652496973675
1691; 0.4674259583155314
1692; 0.5052291274070739
1693; 0.495894185702006
1694; 0.479354457060496
1695; 0.47745385964711506
1696; 0.47813984950383503
1697; 0.48380202452341714
1698; 0.4812036395072937
1699; 0.4712600668271383
1700; 0.4652483105659485
1701; 0.4687760591506958
1702; 0.46947329441706337
1703; 0.5013271570205688
1704; 0.4817326347033183
1705; 0.48112926483154295
1706; 0.489685328801473
1707; 0.43857805331548055
1708; 0.4703304608662923
1709; 0.48847026824951173
1710; 0.46789042552312216
1711; 0.4578484574953715
1712; 0.4850017388

1997; 0.4938201347986857
1998; 0.4911506175994873
1999; 0.4853115955988566


In [6]:
np.save('./saves_recall/recovered_s_control_030p.npy',recovered_s_control)
np.save('./saves_recall/recovered_s_hlesion_030p.npy',recovered_s_hlesion)
np.save('./saves_recall/recovered_s_control_lur_030p.npy',recovered_s_control_lur)
np.save('./saves_recall/recovered_s_hlesion_lur_030p.npy',recovered_s_hlesion_lur)

np.save('./saves_recall/error_known_030p.npy', error_known)
np.save('./saves_recall/error_dual_030p.npy', error_dual)
np.save('./saves_recall/error_mono_030p.npy', error_mono)

np.save('./saves_recall/proj_x_true_030p.npy', proj_x_true)
np.save('./saves_recall/proj_x_known_030p.npy', proj_x_known)
np.save('./saves_recall/proj_x_dual_030p.npy', proj_x_dual)
np.save('./saves_recall/proj_x_mono_030p.npy', proj_x_mono)