In [1]:
import torch
import torch.multiprocessing as mp

In [2]:
from rsnn.spike_sequences.sampling import sample_spike_sequences
from rsnn.spike_sequences.template import segment_spike_sequence
from rsnn.optimization.utils import compute_observation_matrices
from rsnn.optimization.optimization import compute_weights

In [3]:
L = 100  # number of neurons
T = 100  # passed to sample_spike_sequences; presumably sequence length/period — confirm
Tr = 20  # presumably the refractory period (used for sampling and segmentation) — confirm

K = 500  # presumably inputs per neuron (second dim of the delay/source tables)

wb = 0.1  # passed to compute_weights; meaning defined there — TODO document
taub = 60  # upper bound of the uniform delay distribution
a = 1  # passed to compute_weights — TODO document
b = 0  # passed to compute_weights — TODO document
theta = 1.0  # passed to compute_weights; presumably the firing threshold — confirm

beta = 2  # spreading hyperparameter of the impulse response


def impulse_resp(t_: torch.Tensor) -> torch.Tensor:
    """Causal alpha kernel: ``t/beta * exp(1 - t/beta)`` for t >= 0, and 0 for t < 0.

    Peaks at t = beta with value 1. Negative times are clamped to 0 *before*
    the exponential: masking after exp() would evaluate exp(1 - t/beta) for
    very negative t, overflow to inf, and yield 0 * inf = nan.
    """
    t_pos = t_.clamp(min=0)
    return t_pos / beta * torch.exp(1 - t_pos / beta)


def impulse_resp_deriv(t_: torch.Tensor) -> torch.Tensor:
    """Time derivative of ``impulse_resp``: ``1/beta * (1 - t/beta) * exp(1 - t/beta)`` for t >= 0, else 0.

    At t = 0 it returns the right-hand derivative e/beta. As in
    ``impulse_resp``, t is clamped before exp() so the masked-out (t < 0)
    entries are exact zeros instead of nan.
    """
    t_pos = t_.clamp(min=0)
    return (t_ >= 0) * 1 / beta * (1 - t_pos / beta) * torch.exp(1 - t_pos / beta)

# Seed the global torch RNG so delays/sources are reproducible under
# Restart & Run All. NOTE(review): this only seeds torch — if
# sample_spike_sequences draws from numpy/random instead, seed those as well.
SEED = 0
torch.manual_seed(SEED)

# One row per neuron: K delays drawn uniformly between 0 and taub.
delays = torch.empty(L, K).uniform_(0, taub)
# One row per neuron: K neuron indices in [0, L); presumably the presynaptic
# source of each input — confirm against compute_observation_matrices.
sources = torch.randint(0, L, (L, K))

spike_sequences = sample_spike_sequences(L, T, Tr)

In [4]:
# Per-neuron observation matrices: each neuron's own spike sequence is segmented,
# then combined with that neuron's delays and sources. Built in neuron order 0..L-1.
observation_matrices = [
    compute_observation_matrices(
        spike_sequences,
        segment_spike_sequence(spike_sequences[neuron], Tr, 1),
        delays[neuron],
        sources[neuron],
        Tr,
        impulse_resp,
        impulse_resp_deriv,
    )
    for neuron in range(L)
]

In [14]:
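# %%timeit
# weights = torch.empty(L, K)
# for l in range(L):
#     weights[l] = compute_weights(*observation_matrices[l], wb, theta, a, b)
#
# Serial baseline: 18.1 s for 10 neurons (presumably measured with L = 10),
# i.e. roughly 180 s when extrapolated to 100 neurons.
# NOTE(review): this call omits the (l, return_dict) leading arguments that the
# parallel cell passes to compute_weights; one of the two call forms is stale.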

18.1 s ± 327 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
%%timeit
manager = mp.Manager()
return_dict = manager.dict()
jobs = []
for l in range(L):
    p = mp.Process(target=compute_weights, args=(l, return_dict, *observation_matrices[l], wb, theta, a, b))
    jobs.append(p)
    p.start()

for p in jobs:
    p.join()
    
weights = torch.stack([return_dict[l] for l in range(L)])

3min 5s ± 584 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
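# Reference (presumably for a torch.distributed-based alternative to the per-neuron
# Process approach): "Writing Distributed Applications with PyTorch"
# https://pytorch.org/tutorials/intermediate/dist_tuto.html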

In [12]:
%%timeit 
a = torch.rand(100, 100)
b = torch.rand(100, 100)
c = a@b

120 µs ± 330 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [8]:
# Set PyTorch's intra-op thread count for the benchmark that follows. This is
# process-global state: it also affects every cell executed after this one.
N_THREADS = 8
torch.set_num_threads(N_THREADS)
torch.get_num_threads()

In [10]:
%%timeit 
a = torch.rand(100, 100)
b = torch.rand(100, 100)
c = a@b

140 µs ± 2.67 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
