# Basic Simulation Demo

## Intro

The simulation model roughly follows the model described [here](https://ieeexplore.ieee.org/document/5179043), while also leveraging the [dense-to-sparse conversion](https://docs.nvidia.com/cuda/cusparse/#cusparsedensetosparse) from the [cuSPARSE](https://docs.nvidia.com/cuda/cusparse) library.



## Simulation Steps

1. Voltage/Fired update (here: external kernel).
2. DenseToSparse conversion (using "write"-pointers).
3. "read"/"write"-pointer shifts.
4. Synaptic current update (kernel; using "read"-pointers).

## R-STDP (not implemented here)

A first attempt at incorporating [R-STDP](https://arxiv.org/abs/1705.09132) in this simulation-model came with a large performance (and V-RAM) cost likely due to [divergence](https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/#branching-and-divergence). [Dynamic parallelism](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-dynamic-parallelism) [should](https://www.sciencedirect.com/science/article/pii/S0925231218304168) bring some performance back.


## Code

In [1]:
import pycuda.autoinit
import pycuda.driver as cuda
cuda.Device(0).retain_primary_context()

import numpy as np
from pycuda.curandom import rand as curand
import torch

from sim_demo_utils import (
    display_side_by_side,
    make_neurons_states,
    make_n_delays,
    make_n_rep,
    make_n_weights,
    update_N_state_kernel,
    TorchHolder
)

### Simulation Variables

In [2]:
N = 4  # Total number of neurons
S = 3  # Number of synapses per Neuron
D = 3  # Number of possible delays
T = 1000  # Simulation duration

Neuron_types = TorchHolder((torch.arange(N) >= 0.2 * N).cuda().type(torch.int32))  # 0 (inhibitory) or 1 (excitatory)

# Neurons States (IZHIKEVICH MODEL) [rows: pt, u, v, a, b, c, d, i]
Neuron_states = make_neurons_states(N, Neuron_types.tensor)
rt = curand([3], dtype=np.float32)

# Network Representation:
#   column number: pre-synaptic neuron index in {0, ..., N-1},
#   row number:    (local) synapse index from 0 to S -1,
#   values:        post-synaptic neurons indices in {0, ..., N-1}.
Network_representation = TorchHolder(make_n_rep(S=S, N=N).tensor)
display_side_by_side(Neuron_types, Neuron_states, Network_representation,
                     titles=['Neuron_types', 'Neuron_states', 'Network_representation'])

neuron states:
	N0: pt=0.12, u=0.00, v=-65.00, a=-65.00, b=0.20, c=-79.70, d=-2.12, i=0.00
	N1: pt=0.61, u=0.00, v=-65.00, a=-65.00, b=0.23, c=-68.07, d=-12.77, i=0.00
	N2: pt=0.35, u=0.00, v=-65.00, a=-65.00, b=0.25, c=-65.03, d=-13.99, i=0.00
	N3: pt=0.10, u=0.00, v=-65.00, a=-65.00, b=0.21, c=-81.88, d=-7.25, i=0.00


synapses:
	pre-synaptic neuron -> [post-synaptic neuron0, post-synaptic neuron1, ....]
	0 ->  [3, 2, 1]
	1 ->  [0, 2, 1]
	2 ->  [0, 2, 3]
	3 ->  [3, 2, 0]


Unnamed: 0,0,1,2,3
0,0,1,1,1

Unnamed: 0,0,1,2,3
0,0.12,0.61,0.35,0.1
1,0.0,0.0,0.0,0.0
2,-65.0,-65.0,-65.0,-65.0
3,0.02,0.05,0.02,0.08
4,0.2,0.23,0.25,0.21
5,-79.7,-68.07,-65.03,-81.88
6,-2.12,-12.77,-13.99,-7.25
7,0.0,0.0,0.0,0.0

Unnamed: 0,0,1,2,3
0,3,0,0,3
1,2,2,2,2
2,1,1,3,0


In [3]:
Network_delays = make_n_delays(D, N, S)
Neuron_weights = make_n_weights(Neuron_types, Network_representation)
display_side_by_side(Network_delays, Neuron_weights,
                     titles=['Network_delays', 'Neuron_weights'])


delays: neuron, [first synapse index with delay 0, first synapse index with delay 1, ...]
	N0:  [0, 1, 2, 3]
	N1:  [0, 0, 0, 3]
	N2:  [0, 0, 0, 3]
	N3:  [0, 1, 1, 3]

weights:
	N0:  [-0.06, -0.41, -0.41]
	N1:  [5.55, 5.07, 5.24]
	N2:  [5.5, 5.64, 5.47]
	N3:  [5.01, 5.5, 5.17]


Unnamed: 0,0,1,2,3
0,0,0,0,0
1,1,0,0,1
2,2,0,0,1
3,3,3,3,3

Unnamed: 0,0,1,2,3
0,-0.06,5.55,5.5,5.01
1,-0.41,5.07,5.64,5.5
2,-0.41,5.24,5.47,5.17


In [4]:
fired = TorchHolder(torch.zeros(size=(N,), dtype=torch.float32).cuda())
firing_times = TorchHolder(torch.zeros(size=(15, N), dtype=torch.float32).flatten().cuda())
firing_idcs = TorchHolder(torch.zeros(size=firing_times.shape, dtype=torch.int32).flatten().cuda())
firing_counts = TorchHolder(torch.zeros(size=(1, T * 2), dtype=torch.int32).cuda())

fired.print_as_list('fired         =',)
print('firing_times  =', firing_times.as_array, f"(shape=({firing_times.shape[0]}, ))")
print('firing_idcs   =', firing_idcs.as_array, f"(shape=({firing_idcs.shape[0]}, ))")
print('firing_counts =', firing_counts.as_array, f"(shape=({firing_counts.shape[0]}, {firing_counts.shape[1]}))")

fired         = [0.0, 0.0, 0.0, 0.0]
firing_times  = [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] (shape=(60, ))
firing_idcs   = [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (shape=(60, ))
firing_counts = [0 0 0 ... 0 0 0] (shape=(1, 2000))


### Voltage/Fired Update Kernel

```c++
__global__ void update_N_state(
    const int N,
    const float t,
    const float* r,
    const float* rt,
    float* N_states,
    const int* N_type,
    float* fired,
    const float thalamic_inh_input_current = 1.f,
    const float thalamic_exc_input_current = 1.f
)
{
    const int n = blockIdx.x * blockDim.x + threadIdx.x;

    if (n < N)
    {
        fired[n] = 0.f;

        float pt = N_states[n];
        float u = N_states[n + N];
        float v = N_states[n + 2 * N];
        const float a = N_states[n + 3 * N];
        const float b = N_states[n + 4 * N];
        const float c = N_states[n + 5 * N];
        const float d = N_states[n + 6 * N];
        float i = N_states[n + 7 * N];

        if (r[n] < pt)
        {
            const int ntype = N_type[n];
            i += (thalamic_exc_input_current * ntype + thalamic_exc_input_current * (1 - ntype)) * rt[n];
        }

        if (v > 30.f)
        {
            v = c;
            u = u + d;
            fired[n] = t;
        }

        v = v + 0.5f * (0.04f * v * v + 5 * v + 140 - u + i);
        v = v + 0.5f * (0.04f * v * v + 5 * v + 140 - u + i);
        u = u + a * (b * v - u);

        N_states[n + N] = u;
        N_states[n + 2 * N] = v;
        N_states[n + 7 * N] = 0.f;
    }
}
```

In [5]:
def update_N_state(t):
    r = curand([3], dtype=np.float32)
    update_N_state_kernel(np.int32(N), np.float32(t), r,  rt,
                          Neuron_states, Neuron_types, fired,
                          block =(32, 1, 1))

### Relevant Attributes (./.../snn_sim_demo.cuh)
```c++
struct CurrentUpdater
{

    [...]

    float* fired;               // during each step we collect firing times as floats
    // i.e. a neuron n fires at time t, then fired[n] = t (see voltage-update kernel)

    // read (meaning): the "current_update_"-kernel will be applied on all
    //                 firing indices between firing_idcs_read and firing_idcs_write - 1

    // write (meaning): the cusparseDenseToSparse conversion will be executed as if
    //                  the respective arrays would start where the "write"-pointers point


    float* firing_times_write;  // pointer (used by the cusparseDenseToSparse conversion)
    float* firing_times_read;   // pointer (used by the current-update kernel)
    float* firing_times;        // pointer to the start of array

    int* firing_idcs_write;   // pointer (used by the cusparseDenseToSparse conversion)
    int* firing_idcs_read;    // pointer (used by the current-update kernel)
    int* firing_idcs;         // pointer to the start of array

    int* firing_counts_write;  // pointer (used by the cusparseDenseToSparse conversion)
    int* firing_counts;        // pointer (used by the current-update kernel)

    cusparseHandle_t fired_handle;

    cusparseSpMatDescr_t firing_times_sparse;
    cusparseDnMatDescr_t firing_times_dense;

    void* fired_buffer{nullptr};
    size_t fired_buffer_size = 0;

    // counters and thresholds used in the "pointer-update"-function
    int n_fired = 0;
    int n_fired_total = 0;
    int n_fired_total_m1 = 0;  // number
    int n_fired_0 = 0;
    int n_fired_m1 = 0;

    int firing_counts_idx = 1;
    int firing_counts_idx_m1 = 1;

    int reset_firing_times_ptr_threshold;
    int reset_firing_count_idx_threshold;
    int n_fired_m1_to_end = 0;

    [...]
}
```

### Initialization (./.../snn_sim_demo.cu)
```c++
CurrentUpdater::CurrentUpdater(...){

    [...]

    // Pointer Initializations
    fired = fired_;
    firing_times = firing_times_;
    firing_idcs = firing_idcs_;
    firing_counts = firing_counts_;

    // initially all pointers point to the start of the respective array
    firing_times_write = firing_times;
    firing_times_read = firing_times;

    firing_idcs_write = firing_idcs;
    firing_idcs_read = firing_idcs;

    firing_counts_write = firing_counts;

    reset_firing_times_ptr_threshold = 13 * N;
    reset_firing_count_idx_threshold = 2 * T;

    // Cusparse Initializations (once is enough)
    checkCusparseErrors(cusparseCreate(&fired_handle));
    checkCusparseErrors(cusparseCreateDnMat(&firing_times_dense,
        1, N, N,
        fired,
        CUDA_R_32F, CUSPARSE_ORDER_ROW));

    checkCusparseErrors(cusparseCreateCsr(&firing_times_sparse, 1, N, 0,
        firing_counts_write,
        firing_idcs_write,
        firing_times_write,
        CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
        CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F));

    checkCusparseErrors(cusparseDenseToSparse_bufferSize(
        fired_handle, firing_times_dense, firing_times_sparse,
        CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
        &fired_buffer_size));
    checkCudaErrors(cudaMalloc(&fired_buffer, fired_buffer_size));

}
```

### CuSparse Usage

#### Setup

With each simulation step we will convert the "dense matrix" **fired** into the [CSR format](https://docs.nvidia.com/cuda/cusparse/index.html?highlight=cusparse_index_base_zero#compressed-sparse-row-format-csr).

_cusparseDnMatDescr_t_ **firing_times_dense**: $1 \times N$ dense matrix

* array: **fired**
* order: row-major

_cusparseSpMatDescr_t_ **firing_times_sparse**: $1 \times N$ sparse matrix in the CSR format

* Row offsets array: **firing_counts_write**
* Column indices array: **firing_idcs_write**
* Values array: **firing_times_write**

#### Example

Let's assume the "write" pointers point to the beginning of their respective arrays, i.e. **firing_counts_write** == **firing_counts**, etc..

1. The Neurons $N1$ and $N3$ fire at time $t$, i.e.:

<table>
    <tbody>
        <tr>
            <td style="background-color: #aaa"><b>fired: </b></td>
            <td> 0 </td>
            <td style="background-color: #bbb"> t </td>
            <td> 0 </td>
            <td style="background-color: #bbb"> t </td>
        </tr>
    </tbody>
</table>


2. The dense-to-sparse conversion then yields:


<table>
    <tbody>
        <tr>
            <td style="background-color: #aaa"><b>firing_times (values): </b></td>
            <td> t </td>
            <td style="background-color: #bbb"> t </td>
            <td> 0 </td>
            <td style="background-color: #bbb"> ... </td>
            <td> 0 </td>
        </tr>
    </tbody>
</table>

<table>
    <tbody>
        <tr>
            <td style="background-color: #aaa"><b>firing_counts (row offsets):  </b></td>
            <td> 0 </td>
            <td style="background-color: #bbb"> 2 (=nzz)</td>
            <td> 0 </td>
            <td style="background-color: #bbb"> ... </td>
            <td> 0 </td>
        </tr>
    </tbody>
</table>

<table>
    <tbody>
        <tr>
            <td style="background-color: #aaa"><b>firing_idcs_write (column indices): </b></td>
            <td> 1 </td>
            <td style="background-color: #bbb"> 3 </td>
            <td> 0 </td>
            <td style="background-color: #bbb"> ... </td>
            <td> 0 </td>
        </tr>
    </tbody>
</table>


In practice, we use pointer arithmetics to cycle through the output arrays. This allows us to keep a firing history and therefore synaptic delays. The number of required threads is dependent of the number of nonzero elements in the matrix (**nnz**). Once the "read"-pointers are shifted read pointers, we have enough information to launch the "synaptic current update"-kernel.



In [6]:
# noinspection PyUnresolvedReferences
from sim_demo_precompiled import snn_sim_demo_cpp
synaptic_current_updater = snn_sim_demo_cpp.SynapticCurrentUpdater(
    N=N, S=S, D=D, T=T,
    N_rep=Network_representation.get_pointer(),
    N_delays=Network_delays.get_pointer(),
    N_types=Neuron_types.get_pointer(),
    N_states=Neuron_states.get_pointer(),
    N_weights=Neuron_weights.get_pointer(),
    fired=fired.get_pointer(),
    firing_times=firing_times.get_pointer(),
    firing_idcs=firing_idcs.get_pointer(),
    firing_counts=firing_counts.get_pointer())

def print_sim_state():
    Neuron_states.print_row_as_list(2, 'v             =')
    Neuron_states.print_row_as_list(7, 'i             =')
    n_fired_total = synaptic_current_updater.n_fired_total
    if synaptic_current_updater.n_fired > 0:
        print('\nfiring_counts =', firing_counts.as_array[: synaptic_current_updater.t * 2])
        print('firing_times  =',  list(firing_times.as_array[: n_fired_total]))
        print('firing_idcs   =',  list(firing_idcs.as_array[: n_fired_total]))

synaptic_current_updater

SynapticCurrentUpdater(N=4, S=3, D=3, T=1000, t=0)

### Synaptic Current Update Kernel  (./.../snn_sim_demo.cu)

```c++
__global__ void update_current_(
	const int N, const int S, const int D,
	const int* fired_idcs_read, const int* fired_idcs,
	const float* firing_times_read, const float* firing_times,
	const int* N_flags, const int* N_rep, float* N_weights, float* N_states, const int* N_delays,
	const int n_fired_m1_to_end, const int n_fired,
	const int t
)
{
	const int fired_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (fired_idx < n_fired)
    {
        int n;  			// pre-synaptic neuron
        int firing_time;	// firing time of the pre-synaptic neuron

        if (fired_idx < n_fired_m1_to_end)
        {
            // global index of firing-array < len(fired-array)
            // -> use the trailing pointer
            n = fired_idcs_read[fired_idx];
            firing_time = __float2int_rn(firing_times_read[fired_idx]);
        }
        else
        {
            // global index of firing-array >= len(fired-array)
            // -> use the 'normal' pointer
            n = fired_idcs[fired_idx - n_fired_m1_to_end];
            firing_time = __float2int_rn(firing_times[fired_idx - n_fired_m1_to_end]);
        }

        int delay = t - firing_time;  // time passed since the neuron fired
        const int delay_idx = n + N * (delay);

        int snk_N; 		// post-synaptic Neuron-ID
        int idx;		// synapse-index

        // row-index of the first synapse with a delay d_next = delay + 1
        int s_end = N_delays[delay_idx + N];

        float w;		// weight of the synapse

        // loop thourgh all synapses with a delay d == delay
        for (int s = N_delays[delay_idx]; s < s_end; s++)
        {
            idx = n + N * s;		// synapse-index
            snk_N = N_rep[idx];		// post-synaptic Neuron-ID

            w  =  N_weights[idx];
            atomicAdd(&N_states[snk_N + 7 * N], w);
        }
    }

}
```

### Dense-To-Sparse Conversion (./.../snn_sim_demo.cu)

```c++
void CurrentUpdater::dense_to_sparse_conversion(const bool verbose)
{
    // 2. DenseToSparse Conversion (using "write"-pointers)
    checkCusparseErrors(cusparseDenseToSparse_analysis(
        fired_handle, firing_times_dense, firing_times_sparse,
        CUSPARSE_DENSETOSPARSE_ALG_DEFAULT, fired_buffer));

    checkCusparseErrors(cusparseDenseToSparse_convert(
        fired_handle, firing_times_dense, firing_times_sparse,
        CUSPARSE_DENSETOSPARSE_ALG_DEFAULT, fired_buffer));
}
```

In [7]:
import time
def update(verbose=False):
    """
    Simulation steps:

    1. Voltage/Fired update (here: external kernel).
    2. DenseToSparse conversion (using "write"-pointers).
    3. "read"/"write"-pointer shifts.
    4. Synaptic current update (kernel; using "read"-pointers).
    """
    Neuron_states.tensor[7] = 25
    update_N_state(synaptic_current_updater.t)  # 1
    synaptic_current_updater.dense_to_sparse_conversion(verbose=verbose)  # 2
    synaptic_current_updater.shift_sim_pointers()  # 3
    synaptic_current_updater.update_synaptic_current()  # 4

    cuda.Context.synchronize()
    time.sleep(.1)  # cleaner prints
    print_sim_state()

In [8]:
update()

fired         = [0, 0, 0, 0].
v             = [-56.05, -56.05, -56.05, -56.05]
i             = [0.0, 0.0, 0.0, 0.0]


In [9]:
update()

fired         = [0, 0, 0, 0].
v             = [-43.46, -43.0, -43.36, -42.56]
i             = [0.0, 0.0, 0.0, 0.0]


In [10]:
update()

fired         = [0, 0, 0, 0].
v             = [-8.03, -5.15, -7.36, -2.4]
i             = [0.0, 0.0, 0.0, 0.0]


In [11]:
update()

fired         = [0, 0, 0, 0].
v             = [340.76, 398.24, 353.76, 458.18]
i             = [0.0, 0.0, 0.0, 0.0]


In [12]:
update(verbose=False)

fired         = [4, 4, 4, 4].
v             = [-63.07, -48.84, -42.2, -63.77]
i             = [0.0, 0.0, 0.0, 4.95]

firing_counts = [0 0 0 0 0 0 0 0 0 4]
firing_times  = [4.0, 4.0, 4.0, 4.0]
firing_idcs   = [0, 1, 2, 3]


In [13]:
update(verbose=False)

fired         = [0, 0, 0, 0].
v             = [-52.51, -12.64, 18.03, -52.78]
i             = [0.0, 0.0, -0.41, 0.0]

firing_counts = [0 0 0 0 0 0 0 0 0 4 0 0]
firing_times  = [4.0, 4.0, 4.0, 4.0]
firing_idcs   = [0, 1, 2, 3]


In [14]:
update(verbose=False)

fired         = [0, 0, 0, 0].
v             = [-34.3, 291.0, 1142.82, -33.55]
i             = [16.22, 4.84, 16.21, 5.47]

firing_counts = [0 0 0 0 0 0 0 0 0 4 0 0 0 0]
firing_times  = [4.0, 4.0, 4.0, 4.0]
firing_idcs   = [0, 1, 2, 3]


In [15]:
update(verbose=False)

fired         = [0, 7, 7, 0].
v             = [40.67, -38.65, -33.89, 48.14]
i             = [0.0, 0.0, 0.0, 0.0]

firing_counts = [0 0 0 0 0 0 0 0 0 4 0 0 0 0 0 2]
firing_times  = [4.0, 4.0, 4.0, 4.0, 7.0, 7.0]
firing_idcs   = [0, 1, 2, 3, 1, 2]


In [16]:
update(verbose=False)

fired         = [8, 0, 0, 8].
v             = [-60.91, 46.92, 80.1, -56.28]
i             = [0.0, 0.0, 0.0, 4.95]

firing_counts = [0 0 0 0 0 0 0 0 0 4 0 0 0 0 0 2 0 2]
firing_times  = [4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 8.0, 8.0]
firing_idcs   = [0, 1, 2, 3, 1, 2, 0, 3]


In [17]:
update(verbose=False)

fired         = [0, 9, 9, 0].
v             = [-47.08, -25.12, -17.18, -32.19]
i             = [11.05, 5.24, 10.3, 5.47]

firing_counts = [0 0 0 0 0 0 0 0 0 4 0 0 0 0 0 2 0 2 0 2]
firing_times  = [4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 8.0, 8.0, 9.0, 9.0]
firing_idcs   = [0, 1, 2, 3, 1, 2, 0, 3, 1, 2]


In [18]:
update(verbose=True)

t = 10,
n_fired                      = 6,  n_fired_m1_to_end            = 6,
n_fired_0                    = 2,  n_fired_m1                   = 0,
n_fired_total                = 10, n_fired_total_m1             = 4,
firing_counts_write (offset) = 20,
firing_idcs_read    (offset) = 4,  firing_idcs_write  (offset)  = 10,
firing_times_read   (offset) = 4,  firing_times_write (offset)  = 10.
fired         = [0, 0, 0, 0].
v             = [-15.64, 180.97, 291.87, 70.57]
i             = [5.17, -0.41, 5.5, 0.0]

firing_counts = [0 0 0 0 0 0 0 0 0 4 0 0 0 0 0 2 0 2 0 2 0 0]
firing_times  = [4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 8.0, 8.0, 9.0, 9.0]
firing_idcs   = [0, 1, 2, 3, 1, 2, 0, 3, 1, 2]
