# Load packages

In [None]:
import numpy as np
import matplotlib.pyplot as plt
# color-blind color scheme
plt.style.use('tableau-colorblind10')

## Load in old code

We build on the previous notebook:

In [None]:
class LIF_neuron:
    """Single leaky integrate-and-fire (LIF) neuron, advanced one time step at a time.

    ``params`` is a dictionary that must contain the keys
    ``'V_th'`` (spike threshold), ``'V_reset'`` (potential after a spike),
    ``'tau_m'`` (membrane time constant), ``'g_L'`` (leak conductance),
    ``'V_init'`` (initial potential), ``'V_L'`` (leak reversal potential),
    ``'dt'`` (simulation time step) and ``'tau_ref'`` (refractory time).
    """

    def __init__(self, params):
        """Copy the parameters onto the object and set the initial state."""
        # attach parameters to object
        self.V_th, self.V_reset = params['V_th'], params['V_reset']
        self.tau_m, self.g_L = params['tau_m'], params['g_L']
        self.V_init, self.V_L = params['V_init'], params['V_L']
        self.dt = params['dt']
        self.tau_ref = params['tau_ref']

        # start from the configured initial potential; starting at a
        # hard-coded 0.0 would put the neuron above a typical (negative)
        # threshold and make it spike on the very first step
        self.v = self.V_init
        # remaining refractory time steps (0 means "not refractory")
        self.refractory_counter = 0

    def LIF_step(self, I):
        """
            Perform one step of the LIF dynamics.

            I: input current for this time step.

            Returns the tuple (v, currently_spiking): the membrane potential
            after the step and whether a spike was emitted during this step.
        """

        currently_spiking = False

        if self.refractory_counter > 0:
            # still refractory: clamp to the reset potential and count down
            self.v = self.V_reset
            self.refractory_counter = self.refractory_counter - 1
        elif self.v >= self.V_th:
            # threshold reached: record the spike, reset the voltage
            # and start the refractory period
            currently_spiking = True
            self.v = self.V_reset
            # refractory time expressed in time steps (float division)
            self.refractory_counter = self.tau_ref/self.dt
        else:
            # otherwise integrate the current:
            # calculate the increment of the membrane potential
            dv = self.voltage_dynamics(I)
            # update the membrane potential
            self.v = self.v + dv

        return self.v, currently_spiking

    def voltage_dynamics(self, I):
        """
            Calculate the voltage increment of one step of the leaky-integrator
            dynamics: dv = (-(v - V_L) + I/g_L) * dt/tau_m

            Subclasses override this method to change the sub-threshold dynamics.
        """
        dv = (-(self.v-self.V_L) + I/self.g_L) * (self.dt/self.tau_m)
        return dv
        

In [None]:
# define new class as child of old class
class ExpLIF_neuron(LIF_neuron):
    """LIF neuron whose voltage dynamics contain an additional exponential term.

    Needs the keys ``'DeltaT'`` and ``'V_exp_trigger'`` in ``params`` on top of
    everything the parent class reads.
    """

    def __init__(self, params):
        # the parent's __init__ attaches all shared parameters and the state
        super().__init__(params)

        # only the two extra parameters of the exponential term are new
        self.V_exp_trigger = params['V_exp_trigger']
        self.DeltaT = params['DeltaT']

    # now we only need to override the voltage dynamics;
    # LIF_step is inherited and calls this version automatically
    def voltage_dynamics(self, I):
        """
            Calculate the voltage increment of one step of the exp-LI dynamics
        """
        leak = -(self.v - self.V_L)
        drive = I / self.g_L
        upswing = self.DeltaT * np.exp((self.v - self.V_exp_trigger) / self.DeltaT)
        return (leak + drive + upswing) * (self.dt / self.tau_m)
        

In [None]:
class ExpLIF_population:
    """Population of independent exponential LIF neurons, updated with array operations.

    ``params`` must contain the keys ``'V_th'``, ``'V_reset'``, ``'tau_m'``,
    ``'g_L'``, ``'V_init'``, ``'V_L'``, ``'dt'``, ``'tau_ref'``, ``'DeltaT'``,
    ``'V_exp_trigger'`` and ``'n_neurons'``.
    """

    def __init__(self, params):
        """Copy the parameters onto the object and allocate the state arrays."""
        # attach parameters to object
        self.V_th, self.V_reset = params['V_th'], params['V_reset']
        self.tau_m, self.g_L = params['tau_m'], params['g_L']
        self.V_init, self.V_L = params['V_init'], params['V_L']
        self.dt = params['dt']
        self.tau_ref = params['tau_ref']
        self.DeltaT = params['DeltaT']
        self.V_exp_trigger = params['V_exp_trigger']

        # number of neurons
        self.n_neurons = params["n_neurons"]

        # every neuron starts at the configured initial potential
        # (float dtype so that in-place updates never truncate)
        self.v = np.full(self.n_neurons, self.V_init, dtype=float)
        # remaining refractory time steps per neuron (0 means "not refractory")
        self.refractory_counter = np.zeros(self.n_neurons)

    def LIF_step(self, I):
        """
            Perform one step of the LIF dynamics for the whole population.

            I: input current for this time step; it is passed on unchanged to
               voltage_dynamics.

            Returns the tuple (v, currently_spiking). v is the population's
            own voltage array (not a copy; it is modified in place by the
            next call), currently_spiking is a new boolean array with one
            entry per neuron.
        """

        # Boolean masks split the population into three disjoint groups,
        # mirroring the if / elif / else of the single-neuron LIF_step:
        # refractory first, then "at or above threshold", then everything else.
        refractory = self.refractory_counter > 0
        currently_spiking = ~refractory & (self.v >= self.V_th)
        integrating = ~(refractory | currently_spiking)

        # refractory neurons: clamp to the reset potential and count down
        self.v[refractory] = self.V_reset
        self.refractory_counter[refractory] -= 1

        # spiking neurons: reset the voltage and start the refractory period
        self.v[currently_spiking] = self.V_reset
        self.refractory_counter[currently_spiking] = self.tau_ref/self.dt

        # calculate the increment of the membrane potential
        dv = self.voltage_dynamics(I)
        # update the membrane potential only for the integrating neurons
        self.v[integrating] += dv[integrating]

        return self.v, currently_spiking

    def voltage_dynamics(self, I):
        """
            Calculate the voltage increments of one step of the exp-LI dynamics
        """
        # The expression is the same as for a single neuron; because self.v
        # is a numpy array it is evaluated element-wise for all neurons.
        dv = (-(self.v-self.V_L) + I/self.g_L + self.DeltaT * np.exp((self.v-self.V_exp_trigger)/self.DeltaT)) * (self.dt/self.tau_m)
        return dv
        

In [None]:
### typical neuron parameters ###
params = {
    'V_th': -55.,     # spike threshold [mV]
    'V_reset': -75.,  # reset potential [mV]
    'tau_m': 10.,     # membrane time constant [ms]
    'g_L': 10.,       # leak conductance [nS]
    'V_init': -65.,   # initial potential [mV]
    'V_L': -75.,      # leak reversal potential [mV]
    'tau_ref': 2.,    # refractory time [ms]
    'dt': .1,         # simulation time step [ms]
}

# additional parameters for ExpLIF neurons; 'V_th' deliberately
# overrides the LIF value set above
params.update({
    'DeltaT': 10.0,          # sharpness of exponential peak
    'V_exp_trigger': -55.,   # threshold for exponential depolarization [mV]
    'V_th': 0,               # new reset threshold [mV]
})

# Timeit and memory allocation

The code we wrote works, but there are several things that we can do to make it run faster. Let's also time it using the Jupyter cell magic `%%timeit`:

In [None]:
# size of the benchmark: neurons per population and simulated time steps
params.update(n_neurons=1_000, n_steps=10_000)

# mean and standard deviation of the Gaussian input current
mean_I = std_I = 300

### Single neuron code

In [None]:
# I have commented these out because they can run quite slowly

# population1 = [ExpLIF_neuron(params) for _ in range(params["n_neurons"])]

# # these will now become lists of lists (neurons, time steps)
# voltages_arr = []
# spikes_arr = []

In [None]:
# %%timeit -n 1 -r 3

# for i, neuron in enumerate(population1):
#     voltages = []
#     spikes = []
# #     if i % 10 == 0:
# #         print(f"Working on neuron {i}")
#     for _ in range(params["n_steps"]):
#         I = np.random.normal(mean_I, std_I)
#         v, s = neuron.LIF_step(I=I)
#         voltages.append(v)
#         spikes.append(s)
#     voltages_arr.append(voltages.copy())
#     spikes_arr.append(spikes.copy())

### Population code

In [None]:
# fresh population; its state is advanced by the timed loop in the next cell
population2 = ExpLIF_population(params)

# recordings: the loop appends one array of length n_neurons per time step,
# so both become lists of arrays indexed as (time step, neuron)
voltages_arr, spikes_arr = [], []

In [None]:
%%timeit -n 1 -r 10

# NOTE: -r 10 repeats this cell 10 times. population2 and the two lists are
# created in an earlier cell, so every repeat continues from the population's
# current state and appends another n_steps entries to the lists.
for _ in range(params["n_steps"]):
    # independent Gaussian input current for every neuron
    I = np.random.normal(mean_I, std_I, size=params["n_neurons"])
    v, s = population2.LIF_step(I=I)
    # copy before storing: LIF_step returns the population's own voltage
    # array, which is modified in place by the next step
    voltages_arr.append(v.copy())
    spikes_arr.append(s.copy())

An important trick is knowing about memory allocation: every `append` adds a newly allocated array to the list, and the list itself has to be re-allocated each time it outgrows the memory reserved for it.

**For large arrays, this becomes very slow**.

But because we know beforehand how long each simulation is, we can allocate the arrays beforehand and write into them during the simulation:

In [None]:
# fresh population; its state is advanced by the timed loop in the next cell
population3 = ExpLIF_population(params)

# pre-allocated recordings of shape (time steps, neurons);
# row i is filled with the state after simulation step i
voltages_arr = np.zeros((params["n_steps"], params["n_neurons"]))
# spikes are True/False flags, so a boolean array is sufficient
# (1 byte per entry instead of 8 bytes for the default float64)
spikes_arr = np.zeros((params["n_steps"], params["n_neurons"]), dtype=bool)

In [None]:
%%timeit -n 1 -r 10

# NOTE: -r 10 repeats this cell 10 times; every repeat continues from the
# current state of population3 and overwrites the same rows of the arrays.
for i in range(params["n_steps"]):
    # independent Gaussian input current for every neuron
    I = np.random.normal(mean_I, std_I, size=params["n_neurons"])
    # assigning to row i copies the returned values into the pre-allocated arrays
    voltages_arr[i], spikes_arr[i] = population3.LIF_step(I=I)

(Actually, our example here is so small that you will barely see a difference; but for large arrays, I've seen a difference of 300% in simulation speed.)


Let's look at some spike rasters and a histogram:

In [None]:
# window of time steps to display: (first step, one past the last step)
x_range = (9000,10_000)

# Find all spikes inside the window in one vectorized call instead of looping
# over the neurons: nonzero() on the 2-D slice returns the row indices (time
# step relative to the start of the window) and column indices (neuron) of
# every spike.
step_idx, neuron_idx = spikes_arr[x_range[0]:x_range[1]].nonzero()

# a single scatter call draws the whole raster; shift the steps back to
# absolute time by adding the start of the window
plt.scatter(step_idx + x_range[0], neuron_idx, marker='.', c='black')
plt.xlabel('Time step')
plt.ylabel('# Neuron')
plt.show()

# Realizing when it is time to go from a Jupyter notebook to a standalone script

As we have seen, the simulations are becoming larger, with more populations and variables which get overwritten. 

Jupyter notebooks are great for prototyping, but at some point, we have to switch to a proper script. Some reasons are:
- reproducibility and debugging: a common issue with Jupyter notebooks or similar IDEs is that you don't notice when old variables are in use.
You may have restarted the kernel and noticed some plot looks different, without an obvious way to backtrack what was different in the previous execution.
- clarity: scripts can be organized more easily into modules, making it easier to understand which parts are being called at a given time.
- version control: Jupyter notebooks can be a headache for collaborations. When you execute a notebook, the execution counts, outputs and metadata stored in the file change, even if you haven't actually modified the content of any cell. `Git` is not able to tell these apart, and your collaborators/future you will have to dig through every line to see if something has actually changed.

Take a look at `standalone_script.py`, which implements the populations using a parameter file `params.yaml`. Familiarize yourself with both. You can run it with `python standalone_script.py`.

# Multiprocessing

So far, we are using numpy in its simplest form: a single process running on CPU. We can take advantage of multicore systems by using multiprocessing:

In [None]:
import multiprocess as mp   # for multiprocessing

# number of worker processes each mp.Pool below is started with
N_PROCESSES = 4

As a simple example, let's use the squaring of a list:

In [None]:
def f(x):
    """Return ``x`` raised to the power of two."""
    return pow(x, 2)

In [None]:
# start a pool of N_PROCESSES workers; pool.map calls f once per list element
# and collects the results in a list, in the order of the inputs
with mp.Pool(N_PROCESSES) as pool:
    output = pool.map(f, [1, 2, 3])

# bare expression at the end of the cell: the notebook displays its value
output

Great, that works. But what if we want to have a more general f(x), like being able to choose the exponent?

In [None]:
def g(x, n):
    """Return ``x`` raised to the power ``n``."""
    return pow(x, n)

Let's calculate $1^3, 2^3, 3^3$, i.e. $n=3$ for all cases. In the above example, this could be our parameter set `params`.

We might assume that we can pass a tuple or a list, but this fails:

In [None]:
# This cell is expected to fail: pool.map hands each tuple to g as ONE
# positional argument (x), so n is missing and the call raises a TypeError.
# The tuples are (x, n) with n=3 for every input.
with mp.Pool(N_PROCESSES) as pool:
    output = pool.map(g, [(1,3), (2,3), (3,3)])

output

Instead, we need to wrap our function into a partial function:

In [None]:
import functools

# functools.partial returns a new callable that forwards to g with the
# keyword argument n fixed to 3, so only x is left to be supplied
partial_run = functools.partial(g, n=3) # behaves like a version of g that takes one argument less
# bare expression: the notebook displays the partial object
partial_run

In [None]:
# partial_run takes a single argument, so pool.map can call it once per
# list element, just like f above
with mp.Pool(N_PROCESSES) as pool:
    output = pool.map(partial_run, [1, 2, 3])

# bare expression at the end of the cell: the notebook displays its value
output

**Final task:** implement multiprocessing into `standalone_script.py`. To do so, divide the total population into `N_PROCESSES` subpopulations, and run these in parallel.