In [15]:
import numpy as np
from scipy import sparse
import os
from tqdm import trange
import matplotlib.pyplot as plt

In [6]:
from rsnn.utils.analysis import get_phis
from rsnn.utils.utils import load_object_from_file

In [3]:
# (removed) `A = np.identity` bound the function object np.identity, not a
# matrix, and A is reassigned in the next cell, so the statement was dead code.

In [81]:
N = 1000
A = sparse.diags_array(np.ones(N-1), offsets=-1).tocsc()

In [82]:
# Time overwriting row 0 of the CSC matrix with a dense random row.
# NOTE(review): this mutates the global A in place (row 0 becomes dense), and
# that change persists into the A@A timing of the next cell.
%timeit A[0] = np.random.normal(size=N)

88.1 µs ± 515 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [83]:
# Time the sparse CSC product A @ A (A has a dense first row if the previous
# cell ran).
%timeit A@A

69.4 µs ± 408 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [84]:
N = 1000
A = sparse.diags_array(np.ones(N-1), offsets=-1).tolil()

In [85]:
# Same row-overwrite benchmark with A in LIL format (mutates the global A).
%timeit A[0] = np.random.normal(size=N)

161 µs ± 339 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [86]:
# Time the product A @ A with A in LIL format.
%timeit A@A

172 µs ± 321 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [87]:
# Same lower shift matrix as a plain dense ndarray: ones on the k = -1 diagonal.
N = 1000
A = np.eye(N, k=-1)

In [88]:
# Same row-overwrite benchmark on the dense ndarray (a plain row copy).
%timeit A[0] = np.random.normal(size=N)

27.1 µs ± 64.4 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [89]:
# Time the dense N x N product A @ A; this ignores the sparsity of A entirely.
%timeit A@A

34.1 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [113]:
# Sanity check on a random matrix whose rows are normalised to sum to 1:
# Phi @ ones == ones, so 1 is always among the eigenvalues.
# Relies on N from the benchmark cells above.
# NOTE(review): rows whose random sum is close to 0 produce very large entries.
rng = np.random.default_rng(0)  # seeded so the result is reproducible
Phi = rng.normal(size=(N, N))
Phi = Phi / np.sum(Phi, axis=1, keepdims=True)
eigvals = np.linalg.eigvals(Phi)
# `np.topnp` does not exist (AttributeError); show the five largest moduli.
np.sort(np.abs(eigvals))[::-1][:5]

In [103]:
%%timeit
# Dense baseline for the propagation loop used in get_phis: A is an N x N
# lower shift matrix whose first row is replaced every step by a random row
# normalised to sum to 1, and Phi accumulates the left product A @ Phi.
N = 200
A = np.diag(np.ones(N-1),-1)
Phi = np.identity(N)

for m in range(N):
    A[0] = np.random.normal(size=N)
    A[0] /= np.sum(A[0])
    Phi = A@Phi

88 ms ± 671 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [104]:
%%timeit
# Same loop with A stored as a scipy CSC sparse matrix; the normalised row is
# built in a temporary dense vector `r` first and then written into row 0.
# NOTE(review): writing a dense row into a CSC matrix changes its sparsity
# structure, so scipy may emit a SparseEfficiencyWarning here.
N = 200
A = sparse.diags_array(np.ones(N-1), offsets=-1).tocsc()
Phi = np.identity(N)

for m in range(N):
    r = np.random.normal(size=N)
    r /= r.sum()
    A[0] = r
    Phi = A@Phi

20.8 ms ± 80.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [70]:
# Experiment directory (relative to this notebook) for run "exp_0".
exp_dir = os.path.join(
    "..", "scripts", "temporal_stability", "data", "100_400_10_10_20", "exp_0"
)

# NOTE(review): these are .pkl files, presumably unpickled by
# load_object_from_file -- only load files from a trusted source.
spike_trains_path = os.path.join(exp_dir, "spike_trains.pkl")
network_path = os.path.join(exp_dir, "network_0.pkl")

spike_trains = load_object_from_file(spike_trains_path)
network = load_object_from_file(network_path)

In [77]:
# Evaluate the first neuron's input kernel on 10 evenly spaced points in
# [0, 10] (endpoints included), passed as a column vector.
# NOTE(review): the grid step is 10/9, not 1; use np.linspace(0, 10, 11) for unit steps.
# (The unused `it = np.arange(2).reshape(1, -1)` was removed.)
t = np.linspace(0, 10, 10).reshape(-1, 1)
network.neurons[0].input_kernel(t)

array([[0.        ],
       [0.99426591],
       [0.65461073],
       [0.32323989],
       [0.14187774],
       [0.05838145],
       [0.02306252],
       [0.00885735],
       [0.00333232],
       [0.0012341 ]])

In [72]:
# Recall and precision stored on the loaded network (displayed as a pair).
(network.recall, network.precision)

([0.27657094746702204,
  0.11831001637011647,
  0.12215007908852503,
  0.11047481720949602,
  0.09893901038837587,
  0.09835836032074957,
  0.11176721962102952,
  0.10132903609157683],
 [0.3987795368718522,
  0.14073764096439728,
  0.14499502272834683,
  0.13462949210309483,
  0.14193861592283633,
  0.13570128543841573,
  0.13156181960639254,
  0.13552746449581085])

In [34]:
# NOTE(review): this notebook defines get_phis more than once and also shadows
# rsnn.utils.analysis.get_phis imported at the top; the last definition run wins.
def get_phis(neurons, firing_times, period):
    """Eigenvalues of the accumulated spike-to-spike propagation matrix.

    Spikes are processed in chronological order.  For spike ``s_m`` a row
    ``a`` is built whose entry ``n`` sums ``weight * input_kernel_prime(...)``
    over the synapses fed by the neuron of the ``(n + 1)``-th previous spike
    (cyclically, time differences taken modulo ``period``).  The row is
    normalised to sum to 1 and the matrix ``A = [[a], [I_{M-1}, 0]]`` is
    left-multiplied onto the running product ``Phi`` (starting from identity).

    Parameters
    ----------
    neurons : sequence
        Neuron objects exposing ``idx``, ``sources``, ``weights``, ``delays``
        and ``input_kernel_prime``.  They need not be ordered by ``idx``.
    firing_times : sequence of array-like
        ``firing_times[i]`` holds the firing times of the neuron with ``idx == i``.
    period : float
        Period used to wrap time differences.

    Returns
    -------
    numpy.ndarray
        ``np.linalg.eigvals`` of the final ``Phi``; empty if there are no spikes.
    """
    # Build spike times and their neurons from ONE traversal of `neurons`, so
    # they stay aligned even when neurons[i].idx != i (concatenating
    # `firing_times` directly only lines up for idx-ordered neurons).
    spike_times = [firing_times[neuron.idx] for neuron in neurons]
    flat_neurons = [neuron for neuron, times in zip(neurons, spike_times) for _ in times]

    M = len(flat_neurons)
    if M == 0:
        return np.empty(0, dtype=complex)
    flat_firing_times = np.concatenate(spike_times)

    # Chronological order: spike k fires at times[k] from neuron spiking[k].
    order = np.argsort(flat_firing_times)
    times = flat_firing_times[order]
    spiking = [flat_neurons[k] for k in order]

    Phi = np.identity(M)
    row = np.empty(M)  # first row of A; fully overwritten at every step

    for m in trange(M):  # at firing time s_m
        neuron = spiking[m]
        for n in range(M):
            k = (m - n - 1) % M  # the (n + 1)-th previous spike, cyclically
            select = neuron.sources == spiking[k].idx
            row[n] = np.sum(
                neuron.weights[select]
                * neuron.input_kernel_prime((times[m] - times[k] - neuron.delays[select]) % period)
            )
            # Refractory term, commented out in the original; kept for reference:
            # if neuron.idx == spiking[k].idx:
            #     row[n] -= neuron.refractory_kernel_prime((times[m] - times[k]) % period)

        # NOTE(review): a zero row sum yields inf/nan, which np.linalg.eigvals rejects.
        row /= np.sum(row)
        # For A = [[row], [I_{M-1}, 0]], A @ Phi is `row @ Phi` stacked on top of
        # Phi shifted down one row: O(M^2) per spike instead of a dense O(M^3) matmul.
        Phi = np.vstack((row @ Phi, Phi[:-1]))

    # return -np.sort(-np.abs(np.linalg.eigvals(Phi)))
    return np.linalg.eigvals(Phi)

In [35]:
network.phis = get_phis(network.neurons, spike_trains, 50)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 863/863 [00:34<00:00, 24.67it/s]


In [36]:
network.phis[:3]

array([ 193.5910525 +0.j, -185.20601372+0.j,   41.83229441+0.j])

In [17]:
# Build only the first propagation matrix A of get_phis (the m = 0 step) to
# inspect its sparsity in the following cells.
period = 50
m = 0  # chronological index of the spike whose row of A is built

# Spike times and their neurons from ONE traversal of network.neurons, so they
# stay aligned even if network.neurons is not ordered by idx.
spike_times = [spike_trains[neuron.idx] for neuron in network.neurons]
flat_neurons = [neuron for neuron, times in zip(network.neurons, spike_times) for _ in times]
flat_firing_times = np.concatenate(spike_times)

indices = np.argsort(flat_firing_times)
M = len(flat_firing_times)

# A = [[a], [I_{M-1}, 0]]: first row filled below, a down-shift underneath.
A = np.zeros((M, M))
A[1:, :-1] = np.identity(M - 1)

neuron = flat_neurons[indices[m]]
for n in range(M):
    k = indices[(m - n - 1) % M]  # flat index of the (n + 1)-th previous spike, cyclically
    select = neuron.sources == flat_neurons[k].idx
    A[0, n] = np.sum(
        neuron.weights[select]
        * neuron.input_kernel_prime((flat_firing_times[indices[m]] - flat_firing_times[k] - neuron.delays[select]) % period)
    )
    # Refractory term, commented out in the original; kept for reference:
    # if neuron.idx == flat_neurons[k].idx:
    #     A[0, n] -= neuron.refractory_kernel_prime((flat_firing_times[indices[m]] - flat_firing_times[k]) % period)

A[0] /= np.sum(A[0])

In [23]:
# Fraction of non-zero entries in A (A is M x M, so A.size == M**2).
np.count_nonzero(A) / A.size

0.0022799015533675544

In [29]:
# Zero out entries that round to 0 at 9 decimals (|x| < 5e-10), then
# recompute the fraction of non-zero entries.
A_approx = A.round(decimals=9)
np.count_nonzero(A_approx) / A_approx.size

0.0017790751226219137

In [32]:
# Number of strictly negative entries remaining after rounding.
np.sum(A_approx < 0)

210

In [None]:
# NOTE(review): this notebook defines get_phis more than once and also shadows
# rsnn.utils.analysis.get_phis imported at the top; the last definition run wins.
def get_phis(neurons, firing_times, period):
    """Eigenvalues of the accumulated spike-to-spike propagation matrix.

    Spikes are processed in chronological order.  For spike ``s_m`` a row
    ``a`` is built whose entry ``n`` sums ``weight * input_kernel_prime(...)``
    over the synapses fed by the neuron of the ``(n + 1)``-th previous spike
    (cyclically, time differences taken modulo ``period``).  The row is
    normalised to sum to 1 and the matrix ``A = [[a], [I_{M-1}, 0]]`` is
    left-multiplied onto the running product ``Phi`` (starting from identity).

    Parameters
    ----------
    neurons : sequence
        Neuron objects exposing ``idx``, ``sources``, ``weights``, ``delays``
        and ``input_kernel_prime``.  They need not be ordered by ``idx``.
    firing_times : sequence of array-like
        ``firing_times[i]`` holds the firing times of the neuron with ``idx == i``.
    period : float
        Period used to wrap time differences.

    Returns
    -------
    numpy.ndarray
        ``np.linalg.eigvals`` of the final ``Phi``; empty if there are no spikes.
    """
    # Build spike times and their neurons from ONE traversal of `neurons`, so
    # they stay aligned even when neurons[i].idx != i (concatenating
    # `firing_times` directly only lines up for idx-ordered neurons).
    spike_times = [firing_times[neuron.idx] for neuron in neurons]
    flat_neurons = [neuron for neuron, times in zip(neurons, spike_times) for _ in times]

    M = len(flat_neurons)
    if M == 0:
        return np.empty(0, dtype=complex)
    flat_firing_times = np.concatenate(spike_times)

    # Chronological order: spike k fires at times[k] from neuron spiking[k].
    order = np.argsort(flat_firing_times)
    times = flat_firing_times[order]
    spiking = [flat_neurons[k] for k in order]

    Phi = np.identity(M)
    row = np.empty(M)  # first row of A; fully overwritten at every step

    for m in trange(M):  # at firing time s_m
        neuron = spiking[m]
        for n in range(M):
            k = (m - n - 1) % M  # the (n + 1)-th previous spike, cyclically
            select = neuron.sources == spiking[k].idx
            row[n] = np.sum(
                neuron.weights[select]
                * neuron.input_kernel_prime((times[m] - times[k] - neuron.delays[select]) % period)
            )
            # Refractory term, commented out in the original; kept for reference:
            # if neuron.idx == spiking[k].idx:
            #     row[n] -= neuron.refractory_kernel_prime((times[m] - times[k]) % period)

        # NOTE(review): a zero row sum yields inf/nan, which np.linalg.eigvals rejects.
        row /= np.sum(row)
        # For A = [[row], [I_{M-1}, 0]], A @ Phi is `row @ Phi` stacked on top of
        # Phi shifted down one row: O(M^2) per spike instead of a dense O(M^3) matmul.
        Phi = np.vstack((row @ Phi, Phi[:-1]))

    # return -np.sort(-np.abs(np.linalg.eigvals(Phi)))
    return np.linalg.eigvals(Phi)