In [1]:
import sys
sys.path.insert(0, '../../../network')

In [2]:
import logging
import argparse
import numpy as np
from network import Population, RateNetwork
from learning import ReachingTask
from transfer_functions import ErrorFunction
from connectivity import SparseConnectivity, LinearSynapse, ThresholdPlasticityRule, set_connectivity 
from sequences import GaussianSequence
import matplotlib.pyplot as plt
import seaborn as sns
logging.basicConfig(level=logging.INFO)

In [3]:
# Shared transfer function and plasticity rule for the network.
phi = ErrorFunction(mu=0.22, sigma=0.1).phi
plasticity = ThresholdPlasticityRule(x_f=0.5, q_f=0.8)

# populations
ctx = Population(N=1000, tau=1e-2, phi=phi, name='ctx')
d1 = Population(N=1000, tau=1e-2, phi=phi, name='d1')
pops = np.array([ctx, d1])

# patterns: S sequences of P patterns for each population.
# Each sequence gets its own seed (base seed + sequence index). With a single
# fixed seed, every sequence would presumably be identical whenever S > 1.
# For S = 1 the seeds are unchanged (114 for ctx, 29 for d1).
S, P = 1, 3
SEED_CTX, SEED_D1 = 114, 29
sequences_ctx = [GaussianSequence(P, ctx.size, seed=SEED_CTX + i) for i in range(S)]
patterns_ctx = np.stack([s.inputs for s in sequences_ctx])
sequences_d1 = [GaussianSequence(P, d1.size, seed=SEED_D1 + i) for i in range(S)]
patterns_d1 = np.stack([s.inputs for s in sequences_d1])
patterns = [patterns_ctx, patterns_d1]

# 2x2 matrices indexed by population (row/column order matches `pops`).
# NOTE(review): presumably [post, pre] ordering; confirm against set_connectivity.
# connectivity probabilities
cp = np.array([[0.05, 0.05],
               [0.05, 0.05]])
cw = np.array([[0, 0],
               [0, 0]])
A = np.array([[4, 0],
              [1, 4]])

# Selects the learning rule per projection (semantics of 0/1 defined in set_connectivity).
plasticity_rule = np.array([[0, 0],
                            [1, 0]])

J = set_connectivity(pops, cp, cw, A, plasticity_rule, patterns, plasticity)

network = RateNetwork(pops, J, formulation=5)

INFO:connectivity:Building connections from ctx to ctx
INFO:connectivity:Storing attractors
100%|██████████| 1000/1000 [00:00<00:00, 27899.35it/s]
INFO:connectivity:Building connections from ctx to d1
INFO:connectivity:Storing sequences
  0%|          | 0/1 [00:00<?, ?it/s]
100%|██████████| 1000/1000 [00:00<00:00, 28331.85it/s]
100%|██████████| 1/1 [00:00<00:00, 24.96it/s]
INFO:connectivity:Applying synaptic transfer function
INFO:connectivity:Building sparse matrix
INFO:connectivity:Building connections from ctx to d2
INFO:connectivity:Storing sequences
  0%|          | 0/1 [00:00<?, ?it/s]
100%|██████████| 1000/1000 [00:00<00:00, 32789.26it/s]
100%|██████████| 1/1 [00:00<00:00, 30.80it/s]
INFO:connectivity:Applying synaptic transfer function
INFO:connectivity:Building sparse matrix
INFO:connectivity:Building connections from d1 to ctx
INFO:connectivity:Storing attractors
100%|██████████| 1000/1000 [00:00<00:00, 32430.52it/s]
INFO:connectivity:Building connections from d1 to d1
INFO:c

In [4]:
# Initial conditions: ctx starts at phi() of the first pattern of its first
# sequence; d1 starts at zero. This network has only ctx and d1 (see `pops`),
# so no d2 initial conditions are built. The former `d2.size` references raised
# NameError on a fresh kernel.
init_input_ctx = phi(patterns_ctx[0][0])
init_input_d1 = np.zeros(d1.size)
T = 10  # duration argument to simulate_euler2; units defined by RateNetwork (TODO confirm)
mouse = ReachingTask()

# NOTE(review): the recorded output of this cell ends with
# `UnboundLocalError: local variable 'prev_action1' referenced before assignment`,
# which appears to be raised inside RateNetwork.simulate_euler2 (not in this
# notebook). Fix it in the network module.
network.simulate_euler2(mouse, T, init_input_ctx, init_input_d1,
                        patterns_ctx[0], patterns_d1[0], detection_thres=.23,
                        noise1=.13, noise2=0.13)

INFO:network:Integrating network dynamics
  0%|          | 0/9999 [00:00<?, ?it/s]

UnboundLocalError: local variable 'prev_action1' referenced before assignment

In [None]:
# Overlap of each population's activity with its stored sequence patterns.
# The network has only two populations (ctx = pops[0], d1 = pops[1]), so there
# is no d2 overlap: `sequences_d2` was undefined and `pops[2]` is out of range.
overlaps_ctx = sequences_ctx[0].overlaps(network.pops[0])
overlaps_d1 = sequences_d1[0].overlaps(network.pops[1])

In [None]:
# Figure styling for the plots below: white seaborn theme plus the 'deep'
# palette, whose entries the plotting cells pick by index.
sns.set_style(style='white')
colors = sns.color_palette(palette='deep')


In [None]:
# One panel per population. This network has only ctx and d1, so the former D2
# panel (which used the undefined `overlaps_d2`) is dropped.
panels = [("CTX", overlaps_ctx), ("D1", overlaps_d1)]
# (legend label, line style, palette index) for the first three patterns.
pattern_styles = [
    ("Aim", "solid", 8),
    ("Reach", "dashed", 0),
    ("Lick", "dotted", 3),
]

# squeeze=False keeps `axes` an array even when there is a single panel.
fig, axes = plt.subplots(len(panels), 1, sharex=True, sharey=True, squeeze=False,
                         tight_layout=True, figsize=(20, 20))
axes = axes[:, 0]
for ax, (title, overlaps) in zip(axes, panels):
    for (label, linestyle, color_idx), overlap in zip(pattern_styles, overlaps):
        ax.plot(overlap, linestyle=linestyle, linewidth=3,
                color=colors[color_idx], label=label)
    ax.set_yticks([0.0, 0.1, 0.2, 0.3, 0.4])
    ax.set_title(title, fontsize=25)

# With sharex=True only the bottom panel needs an x-axis label.
axes[-1].set_xlabel('Time (ms)', fontsize=20)
fig.text(-0.01, 0.5, 'Overlap', va='center', rotation='vertical', fontsize=20)
plt.setp(axes, xlim=(0, 10000), ylim=(0, 0.4))

# Use explicit handles from one panel so the legend has exactly one entry per pattern.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, fontsize=20)
fig.savefig('./d1d2.png', bbox_inches="tight", format='png')

plt.show()


In [None]:
# D1 (pops[1]): activity and depression-weighted activity (depression * state,
# labelled 'xy'), each averaged over axis 0 (presumably neurons; confirm).
d1_pop = network.pops[1]
fig, ax = plt.subplots()
ax.plot(np.average(d1_pop.state, axis=0), label='firing rate')
ax.plot(np.average(d1_pop.depression * d1_pop.state, axis=0), label='xy')
ax.set_xlim(0, 20000)
ax.set_title('D1')
ax.legend()
fig.savefig('./d1avg.png', bbox_inches="tight", format='png')




In [None]:
# D2 average activity. The network above builds only ctx and d1 (`pops` has two
# entries), so `network.pops[2]` is out of range and the unguarded cell failed.
# Plot only when a third population exists; otherwise say so explicitly.
if len(network.pops) > 2:
    d2_pop = network.pops[2]
    fig, ax = plt.subplots()
    ax.plot(np.average(d2_pop.state, axis=0), label='firing rate')
    ax.plot(np.average(d2_pop.depression * d2_pop.state, axis=0), label='xy')
    ax.set_xlim(0, 20000)
    ax.set_title('D2')
    ax.legend()
    fig.savefig('./d2avg.png', bbox_inches="tight", format='png')
else:
    print('Skipping D2 plot: network has no third population.')