In [4]:
import random
import numpy as np

# Define the pre- and post-synaptic spike times
pre_spike_times = [10, 20, 30]
post_spike_times = [5, 15, 25, 35]

# Initialize the synaptic weight
weight = 0.5

# Set the STDP parameters
tau_plus = 20
tau_minus = 20
a_plus = 0.1
a_minus = 0.1

# Define the STDP function
def stdp(pre_spike_time, post_spike_time, weight):
    if pre_spike_time < post_spike_time:
        weight += a_plus * np.exp(-(post_spike_time - pre_spike_time) / tau_plus)
    elif pre_spike_time > post_spike_time:
        weight -= a_minus * np.exp(-(pre_spike_time - post_spike_time) / tau_minus)
    return weight

# Update the synaptic weight based on pre- and post-synaptic spike times
for i in range(len(pre_spike_times)):
    for j in range(len(post_spike_times)):
        weight = stdp(pre_spike_times[i], post_spike_times[j], weight)

print("Final synaptic weight: ", weight)

Final synaptic weight:  0.5


The runtime of the code is O(N^2), where N is the total number of pre- and post-synaptic spike times. This is because the code contains a nested loop that iterates over all possible combinations of pre- and post-synaptic spike times, resulting in N * N total iterations.



# Optimization
there are several techniques that can be used to optimize the implementation of STDP.

###  Vectorization: 
Instead of iterating over each pair of pre- and post-synaptic spike times, we can use NumPy arrays to vectorize the calculation of the STDP function. This allows us to perform the calculation on all pairs of spikes simultaneously, which can significantly speed up the computation. Here is an example of how to vectorize the STDP calculation:

In [5]:
import numpy as np

# Define the pre- and post-synaptic spike times
pre_spike_times = np.array([10, 20, 30])
post_spike_times = np.array([5, 15, 25, 35])

# Initialize the synaptic weight
weight = 0.5

# Set the STDP parameters
tau_plus = 20
tau_minus = 20
a_plus = 0.1
a_minus = 0.1

# Calculate the time differences between pre- and post-synaptic spikes
delta_t = np.subtract.outer(pre_spike_times, post_spike_times)

# Calculate the weight change for each pair of spikes
weight_change = np.where(delta_t < 0, a_plus * np.exp(delta_t / tau_plus), -a_minus * np.exp(-delta_t / tau_minus))

# Sum the weight changes to get the final weight
weight += np.sum(weight_change)

print("Final synaptic weight: ", weight)

Final synaptic weight:  0.5


### parallelization
The idea behind parallelization is to split the computation across multiple processor cores or machines, so that multiple calculations can be performed simultaneously. This can significantly reduce the overall runtime of the computation.

In [7]:
import numpy as np

# Define the pre- and post-synaptic spike times
pre_spike_times = [10, 20, 30]
post_spike_times = [5, 15, 25, 35]

# Initialize the synaptic weight
weight = 0.5

# Set the STDP parameters
tau_plus = 20
tau_minus = 20
a_plus = 0.1
a_minus = 0.1

# Define the STDP function
def stdp(pre_spike_time, post_spike_time):
    global weight
    delta_t = post_spike_time - pre_spike_time
    if delta_t > 0:
        weight += a_plus * np.exp(-delta_t / tau_plus)
    elif delta_t < 0:
        weight -= a_minus * np.exp(delta_t / tau_minus)

# Compute the weight changes for each pair of spikes
for pre_spike_time in pre_spike_times:
    for post_spike_time in post_spike_times:
        stdp(pre_spike_time, post_spike_time)

print("Final synaptic weight: ", weight)

Final synaptic weight:  0.5


###  GPU acceleration
This involves using a graphics processing unit (GPU) to perform the computations in parallel, which can provide a significant speedup compared to running the computations on a CPU. The TensorFlow and PyTorch libraries provide easy ways to perform GPU-accelerated computations in Python.

You need this: 

pip install torch torchvision torchaudio -f https://download.pytorch.org/whl/cu111/torch_stable.html

In [8]:
import torch

# Define the pre- and post-synaptic spike times
pre_spike_times = [10, 20, 30]
post_spike_times = [5, 15, 25, 35]

# Initialize the synaptic weight on the GPU
weight = torch.tensor(0.5, dtype=torch.float32, device="cuda")

# Set the STDP parameters
tau_plus = 20
tau_minus = 20
a_plus = 0.1
a_minus = 0.1

# Define the STDP function on the GPU
@torch.jit.script
def stdp(pre_spike_time, post_spike_time, weight):
    delta_t = post_spike_time - pre_spike_time
    if delta_t > 0:
        weight += a_plus * torch.exp(-delta_t / tau_plus)
    elif delta_t < 0:
        weight -= a_minus * torch.exp(delta_t / tau_minus)
    return weight

# Compute the weight changes for each pair of spikes on the GPU
for pre_spike_time in pre_spike_times:
    for post_spike_time in post_spike_times:
        weight = stdp(torch.tensor(pre_spike_time, dtype=torch.float32, device="cuda"), torch.tensor(post_spike_time, dtype=torch.float32, device="cuda"), weight)

# Move the final synaptic weight back to the CPU
weight = weight.cpu().detach().numpy()

print("Final synaptic weight: ", weight)

AssertionError: Torch not compiled with CUDA enabled