In [1]:
import sys
# Put the project's `network` directory at the front of the import path.
# The path is relative, so it resolves against the kernel's working directory.
# NOTE(review): this presumably is where the local modules imported in the next
# cell live -- confirm against the repository layout.
sys.path.insert(0, '../../../network')

In [2]:
import logging
import argparse
import numpy as np
from network import Population, RateNetwork
from learning import ReachingTask
from transfer_functions import ErrorFunction
from connectivity import SparseConnectivity, LinearSynapse, ThresholdPlasticityRule
from sequences import GaussianSequence
import matplotlib.pyplot as plt
# np.set_printoptions(threshold=sys.maxsize)
# Configure the root logger to emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)

In [None]:
# --- Populations -----------------------------------------------------------
# Both populations share the same transfer function, size and time constant.
phi = ErrorFunction(mu=0.22, sigma=0.1).phi
ctx = Population(N=1000, tau=1e-2, phi=phi, name='ctx')
bg = Population(N=1000, tau=1e-2, phi=phi, name='bg')
plasticity = ThresholdPlasticityRule(x_f=0.5, q_f=0.8)

# --- Stored patterns -------------------------------------------------------
# S sequences of P patterns for each population; patterns_* stacks the
# per-sequence `inputs` along a new leading axis.
# NOTE(review): every sequence of a population is created with the same seed,
# so with S > 1 the sequences would presumably be identical -- confirm before
# raising S.
S, P = 1, 3
sequences_ctx = [GaussianSequence(P, ctx.size, seed=11) for _ in range(S)]
patterns_ctx = np.stack([seq.inputs for seq in sequences_ctx])
# The bg patterns are sized from the bg population itself (the previous code
# used ctx.size, which only worked because both populations have N=1000).
sequences_bg = [GaussianSequence(P, bg.size, seed=367) for _ in range(S)]
patterns_bg = np.stack([seq.inputs for seq in sequences_bg])

# --- Connectivity ----------------------------------------------------------
# ctx -> ctx: stores the first ctx sequence's patterns onto themselves.
J_cc = SparseConnectivity(source=ctx, target=ctx, p=0.5)
synapse_cc = LinearSynapse(J_cc.K, A=6)
J_cc.store_attractors(patterns_ctx[0], patterns_ctx[0], synapse_cc.h_EE, plasticity.f, plasticity.g)
# bg -> bg: same construction with the bg patterns.
J_bb = SparseConnectivity(source=bg, target=bg, p=0.5)
synapse_bb = LinearSynapse(J_bb.K, A=6)
J_bb.store_attractors(patterns_bg[0], patterns_bg[0], synapse_bb.h_EE, plasticity.f, plasticity.g)
# bg -> ctx: pairs each bg pattern with the ctx pattern, with a smaller A.
J_cb = SparseConnectivity(source=bg, target=ctx, p=0.5)
synapse_cb = LinearSynapse(J_cb.K, A=3)
J_cb.store_attractors(patterns_bg[0], patterns_ctx[0], synapse_cb.h_EE, plasticity.f, plasticity.g)

# ctx -> bg: the only projection built with store_sequences, and with A=0.
# NOTE(review): A=0 presumably means this projection starts with no stored
# structure and is shaped later by simulate_learning -- confirm.
J_bc = SparseConnectivity(source=ctx, target=bg, p=0.5)
synapse_bc = LinearSynapse(J_bc.K, A=0)
J_bc.store_sequences(patterns_ctx, patterns_bg, synapse_bc.h_EE, plasticity.f, plasticity.g)

# --- Networks --------------------------------------------------------------
# Each network receives its own recurrent matrix plus one cross-population
# matrix: J_bc (source ctx, target bg) for ctx, J_cb (source bg, target ctx) for bg.
net_ctx = RateNetwork(ctx, c_EE=J_cc, c_IE=J_bc, formulation=4)
net_bg = RateNetwork(bg, c_II=J_bb, c_EI=J_cb, formulation=4)

In [None]:
# Initial drive for each population: a standard-normal vector with the shape of
# one stored pattern. Both vectors come from a single seeded generator so that
# Restart & Run All reproduces the same starting point (a fresh unseeded
# RandomState per draw made every run different).
# NOTE(review): simulate_learning also takes noise=13; whether it draws from
# its own random source is not visible here, so the full run may still vary.
init_rng = np.random.RandomState(0)
init_input_ctx = init_rng.normal(0, 1, size=patterns_ctx[0][0].shape)
init_input_bg = init_rng.normal(0, 1, size=patterns_bg[0][0].shape)
# Alternative: start exactly on the first stored pattern of each population.
# init_input_ctx = patterns_ctx[0][0]
# init_input_bg = patterns_bg[0][0]
T = 5
behaviors = net_ctx.simulate_learning(net_bg, T, init_input_ctx, init_input_bg, patterns_ctx[0], noise=13)

In [None]:
# Compare the simulated ctx network against the stored ctx sequence twice:
# once with the default measure and once with correlation=True.
# NOTE(review): the result presumably has one row per stored pattern -- confirm
# against GaussianSequence.overlaps.
overlaps = sequences_ctx[0].overlaps(net_ctx, ctx, phi=phi)
correlations = sequences_ctx[0].overlaps(net_ctx, ctx, phi=phi, correlation=True)

In [None]:
# Draw one curve per entry of `correlations`, labelled with its index.
for idx, trace in enumerate(correlations):
    plt.plot(trace, label=str(idx))
plt.legend()

In [None]:
# Persist the simulated behaviors next to the notebook. The path is relative to
# the working directory instead of a user-specific absolute path, so the
# notebook runs on any checkout; the relative sys.path entry
# '../../../network' at the top already assumes the same working directory.
# NOTE(review): the `output` directory must already exist -- np.save does not
# create it.
np.save('output/behaviors.npy', behaviors)

In [20]:
# Action sequence fed to the reaching task: 1000 integers drawn from
# {0, 1, 2, 3}. A dedicated seeded generator is used so the sequence is the
# same on every run and the global NumPy random state is left untouched.
# NOTE(review): ReachingTask may add randomness of its own, in which case the
# printed Q-values can still differ between runs.
# Short fixed sequence for checking the updates by hand:
# behavior = np.array([0, 0, 1, 1, 2, 1, 2, 3, 0, 1, 1, 0, 2, 0, 1, 2])
behavior = np.random.RandomState(0).randint(4, size=1000)

In [22]:
mouse = ReachingTask(2)

# Starting point before any action has been processed.
s1, w1 = 'out', 0
prev = None

# Step through consecutive (current, next) action pairs. For each pair the
# task is asked, in this order, to detect a reward, to advance the state, and
# to apply the TD update using both the current and the next
# (state, action, w) triple.
for a0, a1 in zip(behavior[:-1], behavior[1:]):
    s0, w0 = s1, w1
    mouse.detect_reward(s0, a0, w0)
    s1, w1 = mouse.state_transition(s0, a0, w0)
    mouse.td_learning(s0, a0, w0, s1, a1, w1)
    # Not read by the update itself; kept so the last transition can be
    # inspected interactively after the loop.
    prev = [(s0, a0, w0), (s1, a1, w1)]

print(mouse.Qvalues)

[[[-0.32038344 -0.24919736 -0.62480946 -0.42317768]
  [-0.11438328  0.2551558  -0.5654098  -0.25738157]]

 [[-0.03481723 -0.21374299  0.67378942 -0.04914785]
  [ 0.0499738   0.22073229  0.79670358 -0.06371064]]]


In [None]:
# Scratch check of NumPy indexing: a scalar boolean index does not pick
# element 0 or 1; it inserts a new leading axis (length 0 for False), so this
# evaluates to an empty array of shape (0, 2, 2, 4).
# NOTE(review): presumably probing what happens if a (2, 2, 4) table such as
# the printed Q-values is indexed with a bool instead of an int -- confirm.
np.zeros((2,2,4))[False]