In [1]:
%matplotlib notebook

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

In [71]:
relu = lambda v, v_th=-50., v_peak=-30., peak_rate=50.: peak_rate / (v_peak - v_th) * np.minimum(v_peak - v_th, np.maximum(0., v - v_th))

v = np.linspace(-70., 0., 100)
plt.figure()
plt.plot(v, relu(v))

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7fdfe2c4b690>]

In [229]:
from collections import defaultdict

cell_peak_rate_dict = {'E': 50., 'FFE': 50.}
cell_v_rest_dict = {'E': -70.}  # mV
cell_tau_dict = {'E': 0.05}  # s
cell_r_inp_dict = {'E': 100.}  # MOhm = mV / nA

synapse_g_dict = defaultdict(dict)
synapse_g_dict['E']['FFE'] = 0.001 / max_firing_rate_dict['FFE']  # nA / Hz

max_weight_dict = defaultdict(dict)
max_weight_dict['E']['FFE'] = 1.

synapse_v_reverse_dict = {'FFE': 0.}

synapse_tau_dict = defaultdict(dict)
synapse_tau_dict['E']['FFE'] = 0.01
num_cells_dict = {'E': 10, 'FFE': 10}

weight_init_dict = defaultdict(dict)
weight_init_dict['E']['FFE'] = np.random.uniform(0., max_weight_dict['E']['FFE'], (num_cells_dict['E'], num_cells_dict['FFE']))

num_input_patterns = 10
input_activities_dict = defaultdict(list)
for i in range(num_input_patterns):
    input_activities_dict['FFE'].append(np.random.uniform(0., cell_peak_rate_dict['FFE'], num_cells_dict['FFE']))

In [230]:
plt.figure()
plt.imshow(input_activities_dict['FFE'], aspect='auto')
plt.xlabel('FFE units')
plt.ylabel('Input pattern')
cbar = plt.colorbar()
cbar.set_label('Firing rate (Hz)', rotation=270., labelpad=20)

plt.figure()
plt.imshow(weight_init_dict['E']['FFE'], aspect='auto')
plt.xlabel('FFE units')
plt.ylabel('E units')
cbar = plt.colorbar()
cbar.set_label('Weight', rotation=270., labelpad=20)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [243]:
def flatten_state_dicts(v_dict, i_dict):
    """
    :param v_dict: dict of voltages by population; {pop_name: 1d array of float}
    :param i_dict: dict of synaptic currents by projection; {post_pop_name: {pre_pop_name: 2d array of float}}
    :return: tuple; (list of float, nested dict: tuple of int indexes)
    """

    legend = dict()
    vals = []
    legend['v'] = dict()
    start = 0
    for pop_name in sorted(list(v_dict.keys())):
        vals.extend(v_dict[pop_name])
        end = len(vals)
        legend['v'][pop_name] = (start, end)
        start = end
    legend['i'] = defaultdict(dict)
    for post_pop_name in sorted(list(i_dict.keys())):
        for pre_pop_name in sorted(list(i_dict[post_pop_name].keys())):
            vals.extend(np.ravel(i_dict[post_pop_name][pre_pop_name]))
            end = len(vals)
            legend['i'][post_pop_name][pre_pop_name] = (start, end)
            start = end
    return vals, legend


def expand_states_to_dicts(vals, legend, num_cells_dict):
    """
    :param vals: list or array of float
    :param legend: nested dict: tuple of of int indexes
    :param num_cells_dict: dict of int
    :return: tuple; nested dicts of states by population
    """
    v_dict = dict()
    i_dict = defaultdict(dict)
    for pop_name in sorted(list(legend['v'].keys())):
        start = legend['v'][pop_name][0]
        end = legend['v'][pop_name][1]
        v_dict[pop_name] = vals[start:end]
    for post_pop_name in sorted(list(legend['i'].keys())):
        for pre_pop_name in sorted(list(legend['i'][post_pop_name].keys())):
            start = legend['i'][post_pop_name][pre_pop_name][0]
            end = legend['i'][post_pop_name][pre_pop_name][1]
            i_dict[post_pop_name][pre_pop_name] = vals[start:end].reshape((num_cells_dict[post_pop_name], num_cells_dict[pre_pop_name]))
    return v_dict, i_dict
            

def get_synapse_didt_dict(v_dict, i_dict, w_dict, a_dict, synapse_tau_dict, synapse_g_dict, synapse_v_reverse_dict):
    """
    :param v_dict: dict of cell voltages by population; {pop_name: 1d array of float}
    :param i_dict: dict of synaptic currents by projection; {post_pop_name: {pre_pop_name: 2d array of float}}
    :param w_dict: dict of synaptic weights by projection; {post_pop_name: {pre_pop_name: 2d array of float}}
    :param a_dict: dict of cell activities by population; {post_pop_name: {pre_pop_name: 1d array of float}}
    :param synapse_tau_dict: dict of synaptic decay time constants by projection; {post_pop_name: {pre_pop_name: float}}
    :param synapse_g_dict: dict of synaptic factor to convert rate to current; {post_pop_name: {pre_pop_name: float}}
    :param synapse_v_reverse_dict: dict of synaptic reversal potentials; {pre_pop_name: float}
    :return: dict of didt by projection; {post_pop_name: {pre_pop_name: 2d array of float}}
    """
    didt_dict = defaultdict(dict)
    for post_pop_name in v_dict:
        v = v_dict[post_pop_name]
        for pre_pop_name in i_dict[post_pop_name]:
            w = w_dict[post_pop_name][pre_pop_name]
            a = a_dict[pre_pop_name]
            v_reverse = v_reverse_dict[pre_pop_name]
            g = synapse_g_dict[post_pop_name][pre_pop_name]
            synapse_tau = synapse_tau_dict[post_pop_name][pre_pop_name]
            i = i_dict[post_pop_name][pre_pop_name]
            didt_dict[post_pop_name][pre_pop_name] = (-i + w * g * a * (v_reverse - v)) / synapse_tau
    return didt_dict


def get_cell_dvdt_dict(v_dict, i_dict, cell_v_rest_dict, cell_tau_dict, cell_r_inp_dict):
    """
    :param v_dict: dict of cell voltages by population; {pop_name: 1d array of float}
    :param i_dict: dict of synaptic currents by projection; {post_pop_name: {pre_pop_name: 2d array of float}}
    :param cell_v_rest_dict: dict of cell resting voltages by population; {pop_name: float}
    :param cell_tau_dict: dict of cell membrane decay time constants by population; {pop_name: float}
    :param cell_r_inp_dict: dict of cell input resistance by population; {pop_name: float}
    :return: dict of dvdt by population; {pop_name: 1d array of float}
    """    
    dvdt_dict = dict()
    for post_pop_name in v_dict:
        dvdt_dict[post_pop_name] = -(v_dict[post_pop_name] - cell_v_rest_dict[post_pop_name])
        for pre_pop_name in i_dict[post_pop_name]:
            dvdt_dict[post_pop_name] += cell_r_inp_dict[post_pop_name] * np.sum(i_dict[post_pop_name][pre_pop_name], axis=1)
        dvdt_dict[post_pop_name] /= cell_tau_dict[post_pop_name]
    
    return dvdt_dict


def simulate_network(t, states, state_legend, num_cells_dict, input_pattern_index, input_activities_dict, w_dict, cell_v_rest_dict, cell_peak_rate_dict, 
                     cell_tau_dict, cell_r_inp_dict, synapse_tau_dict, synapse_g_dict, synapse_v_reverse_dict):
    """
    :param t: float
    :param states: 1d array of float
    :param state_legend: nested dict: tuple of int indexes
    :param num_cells_dict: dict of int
    :param input_pattern_index: int
    :param input_activities_dict: dict of input activities by population; {pop_name: float}
    :param w_dict: dict of synaptic weights by projection; {post_pop_name: {pre_pop_name: 2d array of float}}
    :param cell_v_rest_dict: dict of cell resting voltages by population; {pop_name: float}
    :param cell_peak_rate_dict: dict of cell peak firing rates by population; {pop_name: float}
    :param cell_tau_dict: dict of cell membrane decay time constants by population; {pop_name: float}
    :param cell_r_inp_dict: dict of cell input resistance by population; {pop_name: float}
    :param synapse_tau_dict: dict of synaptic decay time constants by projection; {post_pop_name: {pre_pop_name: float}}
    :param synapse_g_dict: dict of synaptic factor to convert rate to current; {post_pop_name: {pre_pop_name: float}}
    :param synapse_v_reverse_dict: dict of synaptic reversal potentials; {pre_pop_name: float}
    :return: 1d array of float
    """
    v_dict, i_dict = expand_states_to_dicts(states, state_legend, num_cells_dict)
    a_dict = dict()
    for pop_name in input_activities_dict:
        a_dict[pop_name] = input_activities_dict[pop_name][input_pattern_index]
    for pop_name in v_dict:
        a_dict[pop_name] = relu(v_dict[pop_name], peak_rate=cell_peak_rate_dict[pop_name])
    
    dvdt_dict = get_cell_dvdt_dict(v_dict, i_dict, cell_v_rest_dict, cell_tau_dict, cell_r_inp_dict)
    didt_dict = get_synapse_didt_dict(v_dict, i_dict, w_dict, a_dict, synapse_tau_dict, synapse_g_dict, synapse_v_reverse_dict)
    dstatesdt, _ = flatten_state_dicts(dvdt_dict, didt_dict)
    
    return dstatesdt

In [245]:
v_init_dict = {'E': np.ones(num_cells['E']) * v_rest_dict['E']}
i_init_dict = defaultdict(dict)
for post_pop_name in weight_init_dict:
    for pre_pop_name in weight_init_dict[post_pop_name]:
        i_init_dict[post_pop_name][pre_pop_name] = np.zeros_like(weight_init_dict[post_pop_name][pre_pop_name])

input_pattern_index = 0
a_init_dict = dict()
for pop_name in v_init_dict:
    a_init_dict[pop_name] = relu(v_init_dict[pop_name], peak_rate=cell_peak_rate_dict[pop_name])
for pop_name in input_activities_dict:
    a_init_dict[pop_name] = input_activities_dict[pop_name][input_pattern_index]

states, state_legend = flatten_state_dicts(v_init_dict, i_init_dict)

In [248]:
plt.figure()
plt.plot(a_init_dict['E'], label='E')
plt.plot(a_init_dict['FFE'], label='FFE')
plt.xlabel('Cell ID')
plt.ylabel('Firing rate (Hz)')
plt.title('Initial activities')
plt.legend(loc='best', frameon=False)

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7fe0051b49d0>

In [250]:
didt_dict = get_synapse_didt_dict(v_init_dict, i_init_dict, weight_init_dict, a_init_dict, synapse_tau_dict, synapse_g_dict, synapse_v_reverse_dict)

dvdt_dict = get_cell_dvdt_dict(v_init_dict, i_init_dict, cell_v_rest_dict, cell_tau_dict, cell_r_inp_dict)

In [255]:
didt_dict['E']['FFE']

array([[1.00509597e+00, 2.11945488e-01, 2.28491697e+00, 4.85396427e-01,
        4.15802296e-01, 1.49250300e+00, 1.81480692e+00, 2.31960313e-01,
        1.09737101e+00, 1.71037432e+00],
       [7.36303539e-02, 4.79977621e-01, 2.40233976e+00, 2.36562989e-02,
        4.70348229e-01, 1.11773639e+00, 1.77122033e-01, 9.23597917e-01,
        1.80158094e-03, 1.44739174e+00],
       [6.27479475e-02, 2.34202578e-01, 6.21899085e-01, 5.31029246e-01,
        9.07075444e-02, 8.94775372e-02, 1.89774499e+00, 1.11477821e-01,
        5.49588265e-01, 1.80124293e+00],
       [9.45215946e-01, 2.24776744e+00, 1.18917109e+00, 3.39513703e-02,
        4.50173906e-01, 5.00670832e-01, 8.80598916e-01, 1.02198001e+00,
        7.16947948e-01, 1.95659481e+00],
       [2.12282200e-01, 2.68494090e-01, 2.83863236e+00, 2.76316910e-02,
        9.88116328e-02, 5.72284183e-01, 1.33706851e+00, 1.02813305e+00,
        1.66379486e+00, 1.88313309e+00],
       [4.21842563e-02, 1.79444873e+00, 1.50715185e+00, 3.27084485e-01,
   

In [256]:
states_init, state_legend = flatten_state_dicts(v_init_dict, i_init_dict)
input_pattern_index = 0

duration = 1.  # sec
num_time_steps = 300
t = np.linspace(0., duration, num_time_steps)
sol = solve_ivp(simulate_network, [0., duration], states_init, 
                args=(state_legend, num_cells_dict, input_pattern_index, input_activities_dict, weight_init_dict, cell_v_rest_dict, cell_peak_rate_dict,
                      cell_tau_dict, cell_r_inp_dict, synapse_tau_dict, synapse_g_dict, v_reverse_dict), dense_output=True)

states = sol.sol(t)

In [257]:
states.shape

(110, 300)

In [258]:
print(states[:10, -1])

[-60.37410825 -63.59448495 -64.66849924 -61.11031031 -61.05357321
 -63.15670093 -60.87985054 -60.28730461 -64.61109004 -62.68045787]


In [259]:
plt.figure()

plt.imshow(states[:10,:], aspect='auto', extent=(0., duration, 10, 0))
plt.xlabel('Time (s)')
plt.ylabel('Cell ID')
plt.title('E population activity')
cbar = plt.colorbar()
cbar.set_label('Voltage (mV)', rotation=270., labelpad=20.)


<IPython.core.display.Javascript object>