# Multi-Spike Tempotron Learning Algorithm for Computing the Theta Critical Gradient 

## 1. Execute all cells that define functions

In [28]:
import numpy as np
from scipy import optimize
rtol = 1e-05  # relative offset (same default as numpy.isclose's rtol) used to probe just below theta*

In [5]:
def gen_omega(n, omega_coefficient, seed=100000):
    """Return n initial synaptic efficacies drawn uniformly from [0, omega_coefficient).

    Args:
        n: number of input neurons (length of the returned array).
        omega_coefficient: scale factor applied to the uniform [0, 1) draws.
        seed: seed passed to ``np.random.seed`` (default 100000, the original value).

    Note: this reseeds NumPy's *global* legacy RNG, so every later ``np.random``
    call (e.g. the trial selection in multispike_training) also becomes
    deterministic relative to this call.
    """
    np.random.seed(seed)
    return np.random.random(n) * omega_coefficient

In [None]:
def desired_fea(n_fea, desired_n, cond = 0, multi = 0):
    """Build the per-feature array of desired output spike counts.

    The first ``desired_n`` entries are the target features; the rest stay 0.
    - cond falsy: every target feature asks for 1 spike.
    - cond truthy and multi truthy: every target feature asks for ``multi`` spikes.
    - cond truthy and multi falsy: the i-th target feature (0-based) asks for i + 1 spikes.

    Returns a float array of length n_fea. Raises IndexError if desired_n > n_fea.
    """
    counts = np.zeros(n_fea)
    for pos in range(desired_n):
        if not cond:
            value = 1
        elif multi:
            value = multi
        else:
            value = pos + 1
        counts[pos] = value
    return counts

In [6]:
def STS(unweighted_input, theta, omega):
    """Simulate the threshold-and-reset neuron at threshold ``theta``.

    V starts as the weighted sum of the input rows (time runs along axis 1).
    Each time V first reaches ``theta``, that bin is recorded as a spike,
    together with V there before the reset. Then ``theta * ref_kernel`` is
    subtracted from that bin onward, truncated at ``ref_memory_len`` or the
    end of the trial.

    Relies on the module-level ``ref_kernel``, ``ref_memory_len`` and
    ``find_first_spike``.

    Returns:
        (spike bins, pre-reset voltages at those bins, max of the final reset trace)
    """
    n_bins = unweighted_input.shape[1]
    V = np.sum(omega[:, np.newaxis] * unweighted_input, axis=0)
    times, pre_reset = [], []
    while True:
        t = find_first_spike(V, theta)
        if t == -1:
            break
        times.append(t)
        pre_reset.append(V[t])
        span = min(ref_memory_len, n_bins - t)
        V[t:t + span] -= theta * ref_kernel[:span]
    return times, pre_reset, max(V)

In [8]:
def bisect(unweighted_input, theta_spikes, theta_range, smallest_range, omega):
    """Recursively bracket the threshold at which the spike count drops below theta_spikes.

    At each level the output is simulated (via STS) at the low end, midpoint and
    high end of ``theta_range``. The search then recurses into the half where the
    count falls from >= theta_spikes to < theta_spikes, until the interval is
    narrower than ``smallest_range``.

    Returns:
        (thetas, counts): every threshold probed, sorted descending, and the
        corresponding spike counts, sorted ascending. Each list is sorted
        independently, so thetas[i] pairs with counts[i] only if the spike count
        decreases monotonically as theta increases.

    Raises:
        ValueError: if theta_range does not bracket theta_spikes. Previously this
            recursed forever on the same range and died with RecursionError.
    """
    lo, hi = theta_range[0], theta_range[1]
    mid = (lo + hi) / 2
    theta_list = [lo, mid, hi]
    spike_list = [len(STS(unweighted_input, t, omega)[0]) for t in theta_list]

    if hi - lo < smallest_range:
        # Same orientation as the recursive case (thetas descending, counts ascending).
        return sorted(theta_list, reverse=True), sorted(spike_list)

    if spike_list[1] < theta_spikes <= spike_list[0]:
        next_range = [lo, mid]
    elif spike_list[2] < theta_spikes <= spike_list[1]:
        next_range = [mid, hi]
    else:
        raise ValueError(
            f"theta_spikes={theta_spikes} is not bracketed by theta_range=[{lo}, {hi}] "
            f"(spike counts {spike_list}); widen theta_range"
        )

    mid_theta, mid_spike = bisect(unweighted_input, theta_spikes, next_range, smallest_range, omega)
    # [1:-1] drops the sub-range endpoints, which are already in theta_list / spike_list.
    return (sorted(np.append(theta_list, mid_theta[1:-1]), reverse=True),
            sorted(np.append(spike_list, mid_spike[1:-1])))

In [9]:
def root_fun(theta_range, unweighted_input, omega):
    """Residual vector handed to scipy.optimize.root by theta_critical.

    Runs STS at threshold theta_range[0] and takes the maximum of the resulting
    (post-reset) voltage trace. Returns [theta_a - V_max, V_max - theta_b], which
    vanishes when both entries of theta_range equal that maximum.
    """
    theta_a, theta_b = theta_range[0], theta_range[1]
    _, _, peak = STS(unweighted_input, theta_a, omega)
    return [theta_a - peak, peak - theta_b]

In [10]:
def theta_critical(unweighted_input, num_spikes, desired_spikes, theta_range, smallest_range, omega):
    """Locate the critical threshold theta* relevant for the current error.

    If the neuron fired too many spikes, the target count is ``num_spikes``;
    if too few, ``num_spikes + 1``. ``bisect`` brackets the threshold where the
    count crosses that target, and ``scipy.optimize.root`` (with ``root_fun``)
    refines it. The neuron is then re-simulated slightly below theta*
    (factor ``1 - rtol``).

    Returns:
        (theta_star, t_star_list): theta_star is the refined threshold;
        t_star_list holds the spike bins up to and including the spike with the
        smallest pre-reset voltage in that re-simulation.

    Raises:
        ValueError: if num_spikes == desired_spikes (there is no error to correct;
            previously this failed with UnboundLocalError), or if the bracket
            returned by bisect does not contain the target count at a usable position.
    """
    if num_spikes > desired_spikes:
        theta_spikes = num_spikes
    elif num_spikes < desired_spikes:
        theta_spikes = num_spikes + 1
    else:
        raise ValueError("theta_critical requires num_spikes != desired_spikes")

    theta_list, spike_list = bisect(unweighted_input, theta_spikes, theta_range, smallest_range, omega)
    hit = spike_list.index(theta_spikes)  # ValueError if the target count was never observed
    if hit == 0:
        # hit - 1 would silently wrap around to the opposite end of the list
        raise ValueError(f"no probed threshold above the one yielding {theta_spikes} spikes")
    start = [theta_list[hit - 1], theta_list[hit]]
    theta_critical_range = optimize.root(root_fun, start, args=(unweighted_input, omega)).x
    theta_star = theta_critical_range[1]

    # Probe just below theta* so that the critical spike is actually emitted.
    spike_time, unresetted_V, _ = STS(unweighted_input, theta_star * (1 - rtol), omega)
    t_star_list = spike_time[:unresetted_V.index(min(unresetted_V)) + 1]
    return theta_star, t_star_list

In [11]:
### Optimized voltage calculation code by Alex ###

def get_memory_len(kernel_array, ratio):
    """
    Return the number of time bins until kernel_array has decayed below ``ratio`` times its maximum.

    The result is one past the last index whose value is still >= ratio * max(kernel_array),
    so ``kernel_array[:memory_len]`` keeps every bin at or above the cutoff. The kernel may
    initially rise before decaying. Unlike a binary search on the reversed array, this does
    not require the array to be monotonic.

    Returns 0 for an empty kernel or when no bin reaches the cutoff.
    """
    kernel_array = np.asarray(kernel_array)
    if kernel_array.size == 0:
        return 0
    above = np.flatnonzero(kernel_array >= ratio * kernel_array.max())
    if above.size == 0:
        return 0
    return int(above[-1]) + 1

def find_first_spike(V, threshold):
    """Return the index of the first element of V that is >= threshold, or -1 if none is.

    Pure-NumPy counterpart of py_find_1st (https://github.com/roebel/py_find_1st).
    """
    above = V >= threshold
    if not above.any():
        return -1
    # argmax returns the first True position of a boolean mask
    return above.argmax()

def calculate_Vt(unweighted_input, theta, omega):
    """Run the threshold-and-reset dynamics and return the final voltage trace and spike bins.

    Same dynamics as STS: the weighted input sum (time along axis 1) is scanned
    for the first bin >= theta. Each such bin is recorded, and
    ``theta * ref_kernel`` (truncated at ``ref_memory_len`` or the trial end)
    is subtracted from it onward.

    Relies on the module-level ``ref_kernel``, ``ref_memory_len`` and ``find_first_spike``.

    Returns:
        (V, spike_time): the reset voltage trace and the list of spike bins.
    """
    n_bins = unweighted_input.shape[1]
    trace = np.sum(omega[:, np.newaxis] * unweighted_input, axis=0)
    fired = []
    t = find_first_spike(trace, theta)
    while t != -1:
        fired.append(t)
        span = min(ref_memory_len, n_bins - t)
        trace[t:t + span] -= theta * ref_kernel[:span]
        t = find_first_spike(trace, theta)
    return trace, fired


In [12]:
def C_tx(spike_time):    #eqn 29
    """Return, for each spike x, the normalization C(t_x) used in eqns 30-32.

    C(t_x) = base + sum of ref_kernel[t_x - t_j] over earlier spikes t_j that are
    still within ``ref_memory_len`` bins of t_x.

    NOTE(review): solving V0(t_x) - theta* * sum_{j<x} K(t_x - t_j) = theta*
    (the reset/threshold rule used in STS) gives C = 1 + sum. This code
    starts the accumulation at 1 + 1 = 2 -- confirm that the offset is intended.
    """
    C = np.zeros(len(spike_time))
    for x, t_x in enumerate(spike_time):
        total = 1 + 1
        for t_j in spike_time[:x]:
            if t_j > t_x - ref_memory_len:
                total += ref_kernel[t_x - t_j]
        C[x] = total
    return C

In [13]:
def dVtx_dwi(unweighted_input, spike_time, C):    #eqn 30
    """Partial derivative of V(t_x) with respect to each weight, at each spike time.

    Takes each input row's unweighted value at the spike bins (columns of
    ``unweighted_input``) and divides column x by C[x].
    Result shape: (rows of unweighted_input, len(spike_time)).
    """
    at_spikes = unweighted_input[:, spike_time]
    return at_spikes / C

In [14]:
def dVtx_dtsk(tot_input, C, spike_time, x, k):    #eqn 31
    """Partial derivative of V(t_x) with respect to the earlier spike time t_k.

    Computes -V0(t_x) / C[x]**2 * exp(-(t_x - t_k) / tau_mem) / tau_mem, using the
    module-level ``tau_mem``.
    """
    t_x, t_k = spike_time[x], spike_time[k]
    decay = np.exp(-(t_x - t_k) / tau_mem) / tau_mem
    return (-tot_input[t_x] / C[x]**2) * decay


In [15]:
def dif_kernel(length, tau_mem, tau_syn, time_ij):
    """Return the time derivative of the normalized double-exponential PSP kernel.

    For bins t = 0 .. length-1 and lag = t - time_ij:
        V_norm * (exp(-lag/tau_syn)/tau_syn - exp(-lag/tau_mem)/tau_mem),
    where V_norm = eta**(eta/(eta-1)) / (eta-1) and eta = tau_mem / tau_syn.

    Vectorized form of the original per-bin loop.

    Raises:
        ValueError: if tau_mem == tau_syn (eta == 1 makes V_norm divide by zero).
    """
    if tau_mem == tau_syn:
        raise ValueError("tau_mem and tau_syn must differ; the kernel normalization diverges when they are equal")
    eta = tau_mem / tau_syn
    V_norm = eta**(eta / (eta - 1)) / (eta - 1)
    lag = np.arange(0., length, 1.) - time_ij  # ms
    return V_norm * (np.exp(-lag / tau_syn) / tau_syn - np.exp(-lag / tau_mem) / tau_mem)

In [16]:
def dV0t_dt(data, diff_kernel, weights=None):    #eqn 32 first part
    """Return the time derivative of the weighted input V0(t) (first part of eqn 32).

    Every nonzero entry data[neuron, bin] contributes weights[neuron] * diff_kernel
    starting at that bin. Contributions are truncated at the end of the trial and
    summed over neurons. This equals a full convolution of the per-bin weighted
    drive with the kernel, cut to the trial length.

    Args:
        data: 2-D array (neurons x time bins); only its nonzero pattern is used.
        diff_kernel: 1-D time derivative of the PSP kernel.
        weights: per-neuron efficacies. Defaults to the module-level ``omega``
            (the original behavior); pass the *current* weights during training.

    Returns:
        1-D array of length data.shape[1].
    """
    if weights is None:
        weights = omega
    weights = np.asarray(weights, dtype=float)
    kernel = np.asarray(diff_kernel, dtype=float)
    n_bins = data.shape[1]
    if n_bins == 0 or kernel.size == 0:
        return np.zeros(n_bins)
    # Total weight of the inputs that are active in each bin.
    drive = (weights[:, np.newaxis] * (np.asarray(data) != 0)).sum(axis=0)
    return np.convolve(drive, kernel)[:n_bins]

In [17]:
def t_derivative(sum_diff_kernel, tot_input, C, spike_time, x):    #eqn 32
    """Time derivative of V(t) = V0(t)/C(t) at spike x.

    dV0/dt(t_x)/C[x] + V0(t_x) * sum_{j<x} exp(-(t_x - t_j)/tau_mem) / (tau_mem * C[x]**2),
    using the module-level ``tau_mem``.
    NOTE(review): this sum is the untruncated exponential, whereas C_tx uses the
    truncated ``ref_kernel`` -- the difference should be negligible; confirm.
    """
    t_x = spike_time[x]
    c_x = C[x]
    reset_decay = sum(np.exp(-(t_x - spike_time[j]) / tau_mem) for j in range(x))
    return sum_diff_kernel[t_x] / c_x + tot_input[t_x] * reset_decay / (tau_mem * c_x**2)


In [18]:
def A_cache(n_spikes, sum_diff_kernel, tot_input, C, spike_time):
    """Fill the A coefficients (eqn 23) for spikes 0 .. n_spikes-1 in order.

    Each A[k] depends on A[0..k-1], so the array being filled is passed to fn_A.
    """
    coeffs = np.zeros(n_spikes)
    for k in range(n_spikes):
        coeffs[k] = fn_A(sum_diff_kernel, tot_input, coeffs, C, spike_time, k)
    return coeffs

In [19]:
def fn_A(sum_diff_kernel, tot_input, A, C, spike_time, k):    #eqn 23
    """Return A_k = 1 - sum_{j<k} A_j / (dV/dt at t_j) * dV(t_k)/dt_j (eqn 23)."""
    coupling = 0
    for j in range(k):
        slope = t_derivative(sum_diff_kernel, tot_input, C, spike_time, j)
        coupling += A[j] / slope * dVtx_dtsk(tot_input, C, spike_time, k, j)
    return 1 - coupling


In [20]:
def B_cache(n_spikes, sum_diff_kernel, tot_input, dVt_dwi, C, spike_time):
    """Fill the B coefficients (eqn 24) for one weight, spikes 0 .. n_spikes-1 in order.

    Each B[k] depends on B[0..k-1], so the array being filled is passed to fn_B.
    """
    coeffs = np.zeros(n_spikes)
    for k in range(n_spikes):
        coeffs[k] = fn_B(sum_diff_kernel, tot_input, coeffs, C, spike_time, dVt_dwi, k)
    return coeffs

In [21]:
def fn_B(sum_diff_kernel, tot_input, B, C, spike_time, dVt_dwi, k):    #eqn 24
    """Return B_k = -dV(t_k)/dw_i - sum_{j<k} B_j / (dV/dt at t_j) * dV(t_k)/dt_j (eqn 24).

    If dVt_dwi is empty, the leading term is taken as 0.
    """
    lead = -dVt_dwi[k] if len(dVt_dwi) != 0 else 0
    coupling = 0
    for j in range(k):
        slope = t_derivative(sum_diff_kernel, tot_input, C, spike_time, j)
        coupling += B[j] / slope * dVtx_dtsk(tot_input, C, spike_time, k, j)
    return lead - coupling


In [22]:
def theta_grad_i(A_cache, B_cache, star):   #Eqn 27
    """Return d(theta*)/dw_i = -B[star] / A[star] (eqn 27)."""
    ratio = B_cache[star] / A_cache[star]
    return -ratio

In [23]:
def theta_critical_grad(n, sum_diff_kernel, tot_input, dVt_dw, A, C, spike_time, n_spikes, star):
    """Return the gradient of theta* with respect to each of the n weights.

    For weight i, the B coefficients are built from row dVt_dw[i]; the shared A
    coefficients are reused. theta_grad_i then gives -B[star]/A[star].
    """
    grad = np.zeros(n)
    for i in range(n):
        B = B_cache(n_spikes, sum_diff_kernel, tot_input, dVt_dw[i], C, spike_time)
        grad[i] += theta_grad_i(A, B, star)
    return grad

In [24]:
def update_w(omega, learning_rate, desired_spikes, num_spikes, sum_diff_kernel, tot_input, dVt_dw, A, C, spike_time, n_spikes, star):
    """Apply one theta*-gradient step to the weights.

    - Too many spikes: move omega against the gradient of theta*.
    - Too few spikes: move omega along it.
    - Correct count: return omega unchanged (the same object).

    The gradient length is taken from ``len(omega)`` rather than the module-level
    ``n``, so the function no longer depends on that hidden global.
    """
    if num_spikes == desired_spikes:
        return omega
    grad = theta_critical_grad(len(omega), sum_diff_kernel, tot_input, dVt_dw, A, C, spike_time, n_spikes, star)
    step = learning_rate * grad
    if num_spikes > desired_spikes:
        return omega - step
    return omega + step

In [25]:
def _dV0_dt_weighted(data, kernel, weights):
    """Return dV0/dt for ``weights``: each nonzero data[neuron, bin] adds weights[neuron] * kernel from that bin on."""
    n_bins = data.shape[1]
    if n_bins == 0 or len(kernel) == 0:
        return np.zeros(n_bins)
    drive = (weights[:, np.newaxis] * (data != 0)).sum(axis=0)
    return np.convolve(drive, kernel)[:n_bins]


def multispike_training(current_omega, n_cycles, data_sets, fea_idx, learning_rate, trials_per_cycle=100):
    """Train the multi-spike tempotron and return the weights after every cycle.

    Each trial loads a random precomputed file ``data/data_<idx>.npz`` with
    idx in [0, data_sets). The desired spike count is sum(n_fea_occur * fea_idx).
    On an error trial, the weights are updated along the theta* gradient
    (see update_w). The time derivative of V0 is computed with the *current*
    weights; previously it used the stale module-level initial ``omega``.

    Relies on the module-level ``theta``, ``theta_range`` and ``diff_kernel``.

    Args:
        current_omega: initial weights (not modified in place).
        n_cycles: number of cycles.
        data_sets: number of data files to sample from.
        fea_idx: desired spikes per occurrence of each feature (see desired_fea).
        learning_rate: step size of each weight update.
        trials_per_cycle: trials per cycle (default 100, the original value).

    Returns:
        List with the weight vector after each cycle.
    """
    cur_omega_list = []
    for _cycle in range(n_cycles):
        for _trial in range(trials_per_cycle):
            idx = int(np.random.random()*data_sets)
            with np.load("data/data_"+str(idx)+".npz") as input_data:
                data = input_data['arr_0']
                presyn_input = input_data['arr_1']
                n_fea_occur = input_data['arr_3']
            desired_spikes = sum(n_fea_occur * fea_idx)
            _, spike_time = calculate_Vt(presyn_input, theta, current_omega)
            num_spikes = len(spike_time)
            if num_spikes == desired_spikes:
                continue
            tot_input = (current_omega[:, np.newaxis] * presyn_input).sum(axis=0)
            _, tx_list = theta_critical(presyn_input, num_spikes, desired_spikes, theta_range, 0.00001, current_omega)
            C = C_tx(tx_list)
            dVt_dw = dVtx_dwi(presyn_input, tx_list, C)
            sum_diff_kernel = _dV0_dt_weighted(data, diff_kernel, current_omega)
            A = A_cache(len(tx_list), sum_diff_kernel, tot_input, C, tx_list)
            t_star = len(tx_list) - 1  # the critical spike is the last one in tx_list
            current_omega = update_w(current_omega, learning_rate, desired_spikes, num_spikes, sum_diff_kernel, tot_input, dVt_dw, A, C, tx_list, len(tx_list), t_star)
        cur_omega_list.append(current_omega)
    return cur_omega_list

## 2. Training Procedures for Multi-Spike Tempotron

### 2.1 Pre-set Parameters

Points to consider before executing the cell below:

a. An omega_coefficient of 0.022 was chosen so that, with the initial synaptic efficacies (omega), the postsynaptic neuron (the multi-spike tempotron) fires at an average rate of approximately 5 Hz.

b. The theta_range was chosen based on the omega_coefficient. If the omega_coefficient is changed, it is recommended to reset the theta_range as well, because the bisection search assumes the critical threshold lies within theta_range; otherwise "out-of-range" errors may occur.

c. The effective dif_kernel_len was selected based on the membrane integration time constant (tau_mem) and the synaptic decay time constant (tau_syn). If tau_mem or tau_syn is changed, it is highly recommended to re-determine the effective dif_kernel_len.

In [27]:
# --- Network size and initial weights ---
n = 500  # number of input (presynaptic) neurons
omega_coefficient = 0.022  # scale of the uniform initial efficacies

# --- Threshold ---
theta = 1  # firing threshold used during training
theta_range = [0.5, 1.5]  # initial bracket searched by bisect() for theta*

# --- Time constants (ms) and kernel length ---
tau_mem = 20
tau_syn = 5
time_ij = 0
dif_kernel_len = 150  # bins kept of the PSP time derivative; the remaining tail is negligible

# Initial synaptic efficacies (gen_omega reseeds NumPy's global RNG)
omega = gen_omega(n, omega_coefficient)

# Refractory (reset) kernel exp(-t / tau_mem), truncated once it falls below 0.1 % of its peak
ref_kernel = np.exp(- np.arange(1000) / tau_mem)
ref_memory_len = get_memory_len(ref_kernel, ratio=0.001)  # refractory memory length in bins
ref_kernel = ref_kernel[:ref_memory_len]

# Time derivative of the PSP kernel
diff_kernel = dif_kernel(dif_kernel_len, tau_mem, tau_syn, time_ij)

### 2.2 Training 

The multispike_training function takes 5 arguments (current_omega, n_cycles, data_sets, fea_idx, and learning_rate) and returns a list of updated omegas. The list stores the omega obtained after every training cycle (100 trials each).

current_omega: initial omega
n_cycles: number of training cycles required
data_sets: available number of precomputed data sets in the pool
fea_idx: array with one entry per feature (as produced by desired_fea) giving the number of output spikes desired for each occurrence of that feature; 0 for distractors
learning_rate: the size of update after each error trial

Notes:

The learning_rate of 0.00001 was chosen according to the original publication (Robert Gütig, 2016).

### 2.2.1 Identification of one feature with one spike 

The desired_fea function generates an array giving the desired number of output spikes for each feature.

To generate one spike to identify only one feature, set desired_n to 1 and cond to 0.

In [None]:
# One target feature, one output spike per occurrence
n_cycles = 1000
data_sets = 19999
learning_rate = 0.00001
n_fea = 10  # features + distractors
desired_n = 1  # how many of the features are targets

fea_idx = desired_fea(n_fea, desired_n, cond=0, multi=0)  # -> [1, 0, ..., 0]

latest_omega1F1S_list = multispike_training(
    omega, n_cycles, data_sets, fea_idx, learning_rate
)
np.save("latest_omega1F1S_list", latest_omega1F1S_list)

### 2.2.2 Identification of one feature with multiple spikes

To generate multiple spikes to identify one feature, set desired_n to 1, cond to 1, and multi to the desired number of spikes.

In [None]:
# One target feature, five output spikes per occurrence
n_cycles = 1000
data_sets = 19999
learning_rate = 0.00001
n_fea = 10  # features + distractors
desired_n = 1  # how many of the features are targets

fea_idx = desired_fea(n_fea, desired_n, cond=1, multi=5)  # -> [5, 0, ..., 0]

latest_omega1F5S_list = multispike_training(
    omega, n_cycles, data_sets, fea_idx, learning_rate
)
np.save("latest_omega1F5S_list", latest_omega1F5S_list)

### 2.2.3 Identification of multiple features with one spike per feature

To generate one spike to identify multiple features, set desired_n to the number of desired features and cond to 0.

In [None]:
# Five target features, one output spike per occurrence of each
n_cycles = 1000
data_sets = 19999
learning_rate = 0.00001
n_fea = 10  # features + distractors
desired_n = 5  # how many of the features are targets

fea_idx = desired_fea(n_fea, desired_n, cond=0, multi=0)  # -> [1, 1, 1, 1, 1, 0, ..., 0]

latest_omega5F1S_list = multispike_training(
    omega, n_cycles, data_sets, fea_idx, learning_rate
)
np.save("latest_omega5F1S_list", latest_omega5F1S_list)

### 2.2.4 Identification of multiple features with a different number of spikes for each feature

To generate a different number of spikes for each desired feature, set desired_n to the number of desired features, cond to 1, and multi to 0 (the i-th desired feature then requires i spikes).

In [None]:
# Five target features, the i-th (1-based) asking for i output spikes
n_cycles = 1000
data_sets = 19999
learning_rate = 0.00001
n_fea = 10  # features + distractors
desired_n = 5  # how many of the features are targets

fea_idx = desired_fea(n_fea, desired_n, cond=1, multi=0)  # -> [1, 2, 3, 4, 5, 0, ..., 0]

latest_omega5FmS_list = multispike_training(
    omega, n_cycles, data_sets, fea_idx, learning_rate
)
np.save("latest_omega5FmS_list", latest_omega5FmS_list)

### 2.2.5 Identification of multiple features with a fixed number of spikes (> 1) for each feature

To generate a fixed number of spikes greater than 1 for each of multiple features, set desired_n to the number of desired features, cond to 1, and multi to the desired number of spikes per feature.

In [None]:
# Five target features, five output spikes per occurrence of each
n_cycles = 1000
data_sets = 19999
learning_rate = 0.00001
n_fea = 10  # features + distractors
desired_n = 5  # how many of the features are targets

fea_idx = desired_fea(n_fea, desired_n, cond=1, multi=5)  # -> [5, 5, 5, 5, 5, 0, ..., 0]

latest_omega5F5S_list = multispike_training(
    omega, n_cycles, data_sets, fea_idx, learning_rate
)
np.save("latest_omega5F5S_list", latest_omega5F5S_list)