In [1]:
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt

from batchAQUA_GPU import *
from batchAQUA_general import batchAQUA

In [2]:
# --- Simulation configuration ---------------------------------------------
T = 5      # s
dt = 0.1   # ms
# round() instead of int(): the float quotient can land just below the
# intended whole number (e.g. 0.3/0.1 == 2.9999999999999996), in which case
# int() would silently drop one iteration. For T=5, dt=0.1 this is 50000.
N_iter = round(1000*T/dt)

N_neurons = 2000

# Every neuron starts from the same 3-component state and the same start time.
# NOTE(review): x_ini holds Python ints, so np.full infers an integer dtype for
# x_start -- confirm that Initialise converts the state to float.
x_ini = np.array([0, 0, 0])
x_start = np.full((N_neurons, 3), fill_value = x_ini)
t_start = np.zeros(N_neurons)

RS = {'name': 'RS', 'C': 100, 'k': 0.7, 'v_r': -60, 'v_t': -40, 'v_peak': 35,
     'a': 0.03, 'b': -2, 'c': -50, 'd': 100, 'e': 0., 'f': 0., 'tau': 0.}    # Class 1

# One parameter dict per neuron. Each entry is a shallow copy so that no two
# models alias the same dict object (appending RS itself N times would make a
# mutation of one entry visible in all of them).
params = [dict(RS) for _ in range(N_neurons)]

batch = batchAQUA_GPU(params)
batch.Initialise(x_start, t_start)

# Injected current, shape (N_neurons, N_iter), float64: row i is the constant
# value i for every timestep. Allocated once on the device and filled by
# broadcasting, rather than building N_neurons host arrays in a Python list
# and copying them over.
I_inj = cp.empty((N_neurons, N_iter), dtype=cp.float64)
I_inj[:] = cp.arange(N_neurons, dtype=cp.float64)[:, None]




In [3]:
# Sanity check: show the model count held by the batch and the Python type it is stored as.
print(batch.N_models, type(batch.N_models), sep="\n")

2000
<class 'int'>


In [4]:
# Run the GPU batch simulation, passing the timestep (dt), the iteration count
# (N_iter) and the injected-current array (I_inj). The three returned values
# are unpacked as X, t and spikes.
# The recorded run of this cell took about 2 min 7 s (see the progress bar below).
X, t, spikes = batch.update_batch(dt, N_iter, I_inj)

100%|██████████| 49999/49999 [02:07<00:00, 393.24it/s]


N_models: 2000
N_iter: 50000
(2000, 50000)
[   0    0    0 ... 1999 1999 1999]


### GPU Results:

#### 2000 neurons simulated for 5 seconds at 0.1 ms resolution (50,000 iterations)

Time taken: 2 minutes 7 seconds


The major slowdown is due to the Python loop over timesteps: each iteration issues its own set of calls to the GPU, so the loop repeatedly pays the per-call overhead.



In [5]:
# CPU baseline: build the general (non-GPU) batch from the same parameter list
# and the same initial state / start times used for the GPU batch.
batch_og = batchAQUA(params)
batch_og.Initialise(x_start, t_start)

# Same dt, N_iter and injected current as the GPU run; I_inj lives on the
# device, so it is copied back to a host NumPy array with cp.asnumpy first.
X2, t2, spikes2 = batch_og.update_batch(dt, N_iter, cp.asnumpy(I_inj))

100%|██████████| 49999/49999 [00:25<00:00, 1994.85it/s]


### CPU Results:

#### 2000 neurons simulated for 5 seconds at 0.1 ms resolution (50,000 iterations)

Time taken: 27 seconds 

In [6]:
# Extract spike times from the CPU run, using the first row of t2 as the time
# axis and the batch's v_peak as the third argument.
# NOTE(review): the recorded output below shows every row as the full time
# grid (0.0 ... 4999.9) -- confirm this is what get_spike_times is meant to return.
spike_times = get_spike_times(X2, t2[0], batch_og.v_peak)
print(spike_times)

N_models: 2000
N_iter: 50000
(2000, 50000)
[   0    0    0 ... 1999 1999 1999]
[[0.0000e+00 1.0000e-01 2.0000e-01 ... 4.9997e+03 4.9998e+03 4.9999e+03]
 [0.0000e+00 1.0000e-01 2.0000e-01 ... 4.9997e+03 4.9998e+03 4.9999e+03]
 [0.0000e+00 1.0000e-01 2.0000e-01 ... 4.9997e+03 4.9998e+03 4.9999e+03]
 ...
 [0.0000e+00 1.0000e-01 2.0000e-01 ... 4.9997e+03 4.9998e+03 4.9999e+03]
 [0.0000e+00 1.0000e-01 2.0000e-01 ... 4.9997e+03 4.9998e+03 4.9999e+03]
 [0.0000e+00 1.0000e-01 2.0000e-01 ... 4.9997e+03 4.9998e+03 4.9999e+03]]
