In [None]:
import os
import json
import pickle
from itertools import chain

import numpy as np
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
%matplotlib inline

from neuron import h
from dlutils import utils
from dlutils.cell import Cell, branch_order
from dlutils.synapse import AMPANMDAExp2Synapse
from dlutils.spine import Spine

### General parameters

In [None]:
# Root folder of the optimization results. It can be overridden through the
# OPTIMIZATION_FOLDER environment variable so the notebook runs on other machines.
optimization_folder = os.environ.get(
    'OPTIMIZATION_FOLDER',
    '/Users/daniele/Postdoc/Research/Janelia/01_model_optimization/')

# Available synaptic cooperativity configurations (select one):
# config_file = 'synaptic_cooperativity_thorny_passive.json'
# config_file = 'synaptic_cooperativity_thorny_active_with_TTX.json'
# config_file = 'synaptic_cooperativity_thorny_active.json'
# config_file = 'synaptic_cooperativity_a-thorny_passive.json'
# config_file = 'synaptic_cooperativity_a-thorny_active_with_TTX.json'
config_file = 'synaptic_cooperativity_a-thorny_active.json'
with open(config_file, 'r') as fid:
    config = json.load(fid)

cell_type = config['cell_type']
# upper-case only the first character ('a-thorny' -> 'A-thorny'): str.capitalize
# would also lower-case the rest of the string
prefix = cell_type[:1].upper() + cell_type[1:]
# the trailing '' keeps a final path separator, so file names can be appended directly
base_folder = os.path.join(optimization_folder, prefix, config['cell_name'],
                           config['optimization_run'], '')
cell_name = config['cell_name'] + '_'
individual = config['individual']

swc_file = base_folder + config['swc_file']
params_file = base_folder + 'individual_{}.json'.format(individual)
# NOTE: from here on `config_file` refers to the optimization's parameters file,
# which is what utils.extract_mechanisms reads below
config_file = base_folder + 'parameters.json'

passive = config['passive']
with_TTX = config['with_TTX']
with open(params_file, 'r') as fid:
    parameters = json.load(fid)
mechanisms = utils.extract_mechanisms(config_file, cell_name)
# pickle.load can execute arbitrary code: only load trusted, locally generated files
with open(base_folder + 'simulation_parameters.pkl', 'rb') as fid:
    sim_pars = pickle.load(fid)
# axon handling follows the settings stored with the optimization run
replace_axon = sim_pars['replace_axon']
add_axon_if_missing = not sim_pars['no_add_axon']

### Functions used to describe the removal of the Mg block from the NMDA synapse

#### Maex & De Schutter
Maex, R., & De Schutter, E. (1998). Synchronization of Golgi and granule cell firing in a detailed network model of the cerebellar granule cell layer. Journal of Neurophysiology, 80(5), 2521–2537. http://doi.org/10.1152/jn.1998.80.5.2521

#### Jahr & Stevens
Jahr, C. E., & Stevens, C. F. (1990). A quantitative description of NMDA receptor-channel kinetic behavior. The Journal of Neuroscience, 10(6), 1830–1837.

Jahr, C. E., & Stevens, C. F. (1990). Voltage dependence of NMDA-activated macroscopic conductances predicted by single-channel kinetics. The Journal of Neuroscience, 10(9), 3178–3182. http://doi.org/10.1523/JNEUROSCI.10-09-03178.1990

#### Harnett
Harnett, M. T., Makara, J. K., Spruston, N., Kath, W. L., & Magee, J. C. (2012). Synaptic amplification by dendritic spines enhances input cooperativity. Nature, 491(7425), 599–602. http://doi.org/10.1038/nature11554

Note that the Mg unblock function printed in the Harnett et al. (2012) paper has inconsistent units. The correct expression is the one implemented below (`mg_harnett`). It is the same as equation 5 of the second Jahr & Stevens (1990) paper.

In [None]:
# Membrane voltages (mV) at which the Mg unblock curves are compared
v = np.arange(-100, 50)
extMgConc = config['NMDA']['extMgConc']
nmda_model = config['NMDA']['model']


def mg_maex_orig(v):
    """Mg unblock of the original Maex & De Schutter (1998) model."""
    return 1. / (1. + 0.2801 * extMgConc * np.exp(-0.062 * (v - 10)))


def mg_jahr_stevens_orig(v):
    """Mg unblock of the original Jahr & Stevens (1990) model."""
    return 1. / (1 + (extMgConc / 9.888) * np.exp(0.09137 * (2.222 - v)))


def mg_harnett(v):
    """Mg unblock used by Harnett et al. (2012), with corrected units."""
    return 1. / (1. + (extMgConc / 3.57) * np.exp(-0.062 * v))


fig, ax = plt.subplots(figsize=(8,4))
ax.plot(v, mg_harnett(v), 'k--', lw=3, label='Harnett')
ax.plot(v, mg_maex_orig(v), color=[0,.5,0], label='Original Maex & De Schutter')

if nmda_model == 'MDS':
    # parameters of the modified Maex & De Schutter curve, read from the config
    alpha_vspom = config['NMDA']['alpha_vspom'] # -0.124
    v0_block    = config['NMDA']['v0_block']    # -10
    eta         = config['NMDA']['eta']         # 0.02801

    def mg_maex_mod(v):
        """Maex & De Schutter unblock with the configured parameters."""
        return 1. / (1. + eta * extMgConc * np.exp(alpha_vspom * (v - v0_block)))

    ax.plot(v, mg_maex_mod(v), color=[.5,1,.5], label='Modified Maex & De Schutter')

ax.plot(v, mg_jahr_stevens_orig(v), color=[.65,0,0], label='Original Jahr & Stevens')

if nmda_model == 'JS':
    # parameters of the modified Jahr & Stevens curve, read from the config
    Kd    = config['NMDA']['Kd']    # (mM)
    gamma = config['NMDA']['gamma'] # (mV^-1)
    sh    = config['NMDA']['sh']    # (mV)

    def mg_jahr_stevens_mod(v):
        """Jahr & Stevens unblock with the configured parameters."""
        return 1. / (1 + (extMgConc / Kd) * np.exp(gamma * (sh - v)))

    ax.plot(v, mg_jahr_stevens_mod(v), color=[1,.5,.5], label='Modified Jahr & Stevens')

ax.set_xlabel('Voltage (mV)')
ax.set_ylabel('Mg unblock')
ax.legend(loc='best');

### Instantiate the cell

In [None]:
# random numeric suffix for the cell name (presumably to avoid NEURON name clashes
# when the notebook is re-run in the same kernel)
cell_suffix = int(np.random.uniform()*1e5)
cell = Cell('CA3_cell_%d' % cell_suffix, swc_file, parameters, mechanisms)
cell.instantiate(replace_axon, add_axon_if_missing,
                 force_passive=passive, TTX=with_TTX)

# apical section that will host the spines; its axial resistivity, scaled by
# Ra_neck_coeff, is the Ra handed to the spines
section_num = config['section_num']
section = cell.morpho.apic[section_num]
Ra = section.Ra * config['Ra_neck_coeff']
print('Branch order of section {}: {}.'.format(section.name(), branch_order(section)))

### Instantiate the spines

In [None]:
# Spine geometry, all lengths in um. In the Harnett paper the head is a sphere with
# a diameter of 0.5 um: a cylinder with diameter and length equal to 0.5 um has the
# same (outer) surface area as that sphere.
spine_cfg = config['spine']
head_L, head_diam = spine_cfg['head_L'], spine_cfg['head_diam']
neck_L, neck_diam = spine_cfg['neck_L'], spine_cfg['neck_diam']
spine_distance = config['spine_distance']    # [um] distance between neighboring spines
n_spines = config['n_spines']                # number of spines
L = spine_distance * (n_spines - 1)          # [um] extent of the spine cluster
norm_L = L / section.L                       # the same extent, as a fraction of the section

# The cluster is centred on the middle of apic[0] (and of apic[10] in thorny cells);
# on any other section it is centred at 20% of the section length.
centred_on_middle = section_num == 0 or (cell_type == 'thorny' and section_num == 10)
center = 0.5 if centred_on_middle else 0.2
spine_positions = np.linspace(center - norm_L/2, center + norm_L/2, n_spines)
spines = [Spine(section, x, head_L, head_diam, neck_L, neck_diam, Ra, i)
          for i, x in enumerate(spine_positions)]

for spine in spines:
    spine.instantiate()

#### Check the location of the spines in terms of distinct segments

In [None]:
# Group consecutive spines that attach to the same segment of the dendritic branch:
# segments[g] is a segment, segments_idx[g] the indices of the spines attached to it.
segments = [section(spines[0]._sec_x)]
segments_idx = [[0]]
for k in range(1, len(spines)):
    seg = section(spines[k]._sec_x)
    if seg != segments[-1]:
        segments.append(seg)
        segments_idx.append([k])
    else:
        segments_idx[-1].append(k)

n_groups = len(segments_idx)
if n_groups == 1:
    print('All spines are connected to the same segment.')
elif n_groups == n_spines:
    print('Each spine is connected to a different segment on the dendritic branch.')
else:
    for group in segments_idx:
        if len(group) == 1:
            print('Spine {} is connected to a distinct segment.'.format(group[0]))
        else:
            print('Spines {} are connected to the same segment.'.format(group))

#### Show where the spines are located on the dendritic tree

In [None]:
# Top view (x-y) of the dendritic tree: apical sections in black, basal in blue,
# each labelled with its index; spine points are drawn in red.
fig, ax = plt.subplots(figsize=(10,10))
for sec in chain(cell.morpho.apic, cell.morpho.basal):
    color = 'k' if sec in cell.morpho.apic else 'b'
    # index of the section, taken from names of the form '<cell>.apic[12]'
    lbl = sec.name().split('.')[1].split('[')[1][:-1]
    n = sec.n3d()
    sec_coords = np.array([[sec.x3d(k), sec.y3d(k)] for k in range(n)],
                          dtype=float).reshape(n, 2)
    middle = n // 2
    ax.text(sec_coords[middle, 0], sec_coords[middle, 1], lbl,
            fontsize=14, color='m')
    ax.plot(sec_coords[:, 0], sec_coords[:, 1], color, lw=1)
for spine in spines:
    ax.plot(spine._points[:, 0], spine._points[:, 1], 'r.')
ax.axis('equal');

### Insert a synapse into each spine

In [None]:
# Numeric codes of the Mg unblock models understood by the NMDA synapse
MG_MODELS = {'MDS': 1, 'HRN': 2, 'JS': 3}
# Config parameters each Mg unblock model copies onto the NMDA synapse
# (the Harnett model uses its default parameters)
MG_MODEL_PARAMS = {'MDS': ('alpha_vspom', 'v0_block', 'eta'),
                   'HRN': (),
                   'JS': ('Kd', 'gamma', 'sh')}
Mg_unblock_model = config['NMDA']['model']

E = 0        # [mV]

AMPA_taus = config['AMPA']['time_constants']
NMDA_taus = config['NMDA']['time_constants']
weights = np.array([config['AMPA']['weight'], config['NMDA']['weight']])

print('AMPA:')
print('    tau_rise = {:.3f} ms'.format(AMPA_taus['tau1']))
print('   tau_decay = {:.3f} ms'.format(AMPA_taus['tau2']))
print('NMDA:')
print('    tau_rise = {:.3f} ms'.format(NMDA_taus['tau1']))
print('   tau_decay = {:.3f} ms'.format(NMDA_taus['tau2']))

# one AMPA+NMDA synapse on the head of each spine
synapses = [AMPANMDAExp2Synapse(spine.head, 1, E, weights, AMPA = AMPA_taus,
                                NMDA = NMDA_taus) for spine in spines]

for syn in synapses:
    # an unknown model name raises KeyError here
    syn.nmda_syn.mg_unblock_model = MG_MODELS[Mg_unblock_model]
    for par in MG_MODEL_PARAMS[Mg_unblock_model]:
        setattr(syn.nmda_syn, par, config['NMDA'][par])

if Mg_unblock_model == 'MDS':
    print('\nUsing Maex & De Schutter Mg unblock model. Modified parameters:')
    print('       alpha = {:.3f} 1/mV'.format(synapses[0].nmda_syn.alpha_vspom))
    print('    v0_block = {:.3f} mV'.format(synapses[0].nmda_syn.v0_block))
    print('         eta = {:.3f}'.format(synapses[0].nmda_syn.eta))
elif Mg_unblock_model == 'JS':
    print('\nUsing Jahr & Stevens Mg unblock model. Modified parameters:')
    # Kd is a dissociation constant, i.e. a Mg concentration (mM)
    print('          Kd = {:.3f} mM'.format(synapses[0].nmda_syn.Kd))
    print('       gamma = {:.3f} 1/mV'.format(synapses[0].nmda_syn.gamma))
    print('          sh = {:.3f} mV'.format(synapses[0].nmda_syn.sh))
elif Mg_unblock_model == 'HRN':
    print('\nUsing Harnett Mg unblock model with default parameters.')

#### Make the recorders

In [None]:
def _make_recorder(ref):
    """Return a new h.Vector that records the NEURON variable pointed to by `ref`."""
    vec = h.Vector()
    vec.record(ref)
    return vec


soma_seg = cell.morpho.soma[0](0.5)
rec = {'t': _make_recorder(h._ref_t),
       'Vsoma': _make_recorder(soma_seg._ref_v)}
for i, spine in enumerate(spines):
    # syn[0] is the AMPA and syn[1] the NMDA point process of the spine's synapse
    ampa_pp, nmda_pp = synapses[i].syn[0], synapses[i].syn[1]
    rec['Vdend-{}'.format(i)] = _make_recorder(spine._sec(spine._sec_x)._ref_v)
    rec['Vspine-{}'.format(i)] = _make_recorder(spine.head(0.5)._ref_v)
    rec['IAMPA-{}'.format(i)] = _make_recorder(ampa_pp._ref_i)
    rec['gAMPA-{}'.format(i)] = _make_recorder(ampa_pp._ref_g)
    rec['INMDA-{}'.format(i)] = _make_recorder(nmda_pp._ref_i)
    rec['gNMDA-{}'.format(i)] = _make_recorder(nmda_pp._ref_g)
    rec['MgBlock-{}'.format(i)] = _make_recorder(nmda_pp._ref_mgBlock)

# Additional recorders along the path from the site of the spine closest to the soma
# all the way to the soma, to see how the voltage attenuates on its way to the cell
# body. Each list holds one group per section on the path (the soma is the last group),
# and each group holds one entry per recorded segment.
prop_rec, prop_rec_ina, prop_rec_ik, prop_rec_ica, prop_rec_cai, seg_area = \
    ([] for _ in range(6))


def _record_group(segs):
    """Append one group of v/ina/ik/ica/cai recorders and areas for the segments `segs`."""
    prop_rec.append([_make_recorder(s._ref_v) for s in segs])
    prop_rec_ina.append([_make_recorder(s._ref_ina) for s in segs])
    prop_rec_ik.append([_make_recorder(s._ref_ik) for s in segs])
    prop_rec_ica.append([_make_recorder(s._ref_ica) for s in segs])
    prop_rec_cai.append([_make_recorder(s._ref_cai) for s in segs])
    seg_area.append([s.area() for s in segs])


first_spine_seg = section(spines[0]._sec_x)
sec = section
while sec != cell.morpho.soma[0]:
    path_segs = []
    for seg in sec:
        path_segs.append(seg)
        # on the spines' own branch, stop at the segment hosting the first spine
        if sec == section and seg == first_spine_seg:
            break
    _record_group(path_segs)
    sec = sec.parentseg().sec
_record_group([soma_seg])

#### Compute the presynaptic spike times

In [None]:
t0 = 1000.       # [ms]
dt = 1000.       # [ms] interval between successive volleys

single_trial = True
with_somatic_current_injection = False

# Presynaptic spike trains: synapse k fires at every volley from t0 + k*dt onwards.
# The m-th volley (at t0 + m*dt) therefore activates synapses 0..m, so the number of
# co-active spines grows by one per volley.
if with_somatic_current_injection:
    # one extra volley, delivered together with the somatic current step below
    spike_times = [np.sort(t0 + n_spines * dt - np.arange(i) * dt) for i in range(n_spines+1, 0, -1)]
    spike_times = spike_times[:-1]
else:
    spike_times = [np.sort(t0 + (n_spines - 1) * dt - np.arange(i) * dt) for i in range(n_spines, 0, -1)]

# [ms] delay between successive spines (sequential) or mean inter-spike interval (Poisson)
isi = 0.3
poisson = False
sequential = True
if sequential:
    # within each volley, spine i is activated i*isi ms after spine 0
    for i in range(n_spines):
        spike_times[i] += i * isi
elif poisson:
    for i in range(n_spines):
        # exponentially distributed intervals with mean `isi`, expressed in ms so that
        # they can be added to spike_times (which are in ms)
        ISI = - np.log(np.random.uniform(size=n_spines*2)) * isi
        spks = np.cumsum(ISI)
        for j in range(n_spines - i):
            spike_times[j][i] += spks[j]

if single_trial:
    # keep one volley: each synapse fires once, at t0 plus the offset of its last spike
    # relative to the last spike of synapse 0
    spike_times = [np.array([spks[-1] - spike_times[0][-1] + t0]) for spks in spike_times]

for syn, spks in zip(synapses, spike_times):
    syn.set_presynaptic_spike_times(spks)

if with_somatic_current_injection:
    # `stim` was previously used without ever being created (NameError)
    stim = {'soma': h.IClamp(cell.morpho.soma[0](0.5))}
    stim['soma'].amp = 0.06      # [nA]
    stim['soma'].dur = 100       # [ms]
    stim['soma'].delay = spike_times[0][-1]

In [None]:
# display the presynaptic spike times assigned to each synapse (one array per synapse)
spike_times

#### Run the simulation

In [None]:
# h.run() is defined by NEURON's standard run system; load_file does nothing if
# stdrun.hoc has already been loaded (e.g. by dlutils), so this is safe on re-runs
h.load_file('stdrun.hoc')
h.cvode_active(1)    # variable time step integration
# simulate until one volley interval after the last spike of synapse 0,
# plus one more interval when the somatic current step is delivered
h.tstop = spike_times[0][-1] + dt
if with_somatic_current_injection:
    h.tstop += dt
h.run()

#### Get the data from the recorders

In [None]:
def _spine_traces(key_fmt, scale=None):
    """Stack the recordings rec[key_fmt.format(k)], k = 0..n_spines-1, one row per spine.

    When `scale` is given, every trace is multiplied by it.
    """
    traces = [np.array(rec[key_fmt.format(k)]) for k in range(n_spines)]
    if scale is not None:
        traces = [trace * scale for trace in traces]
    return np.array(traces)


t = np.array(rec['t'])
# synaptic currents and conductances are scaled by 1e3
# (presumably nA -> pA and uS -> nS; confirm against the synapse model)
iampa = _spine_traces('IAMPA-{}', 1e3)
inmda = _spine_traces('INMDA-{}', 1e3)
gampa = _spine_traces('gAMPA-{}', 1e3)
gnmda = _spine_traces('gNMDA-{}', 1e3)
MgBlock = _spine_traces('MgBlock-{}')
Vspine = _spine_traces('Vspine-{}')
Vdend = _spine_traces('Vdend-{}')
Vsoma = np.array(rec['Vsoma'])

# same nested layout as the recorders: one list per section on the path, soma last
prop_V, prop_INa, prop_IK, prop_ICa, prop_Cai = (
    [[np.array(vec) for vec in grp] for grp in recorders]
    for recorders in (prop_rec, prop_rec_ina, prop_rec_ik, prop_rec_ica, prop_rec_cai))

#### Plot the results

In [None]:
window = [5,100]    # [ms] before and after the last spike of synapse 0
t_last = spike_times[0][-1]
idx, = np.where((t > t_last - window[0]) & (t < t_last + window[1]))

fig, ax = plt.subplots()
# prop_V[0][0]: first recorded segment of the spines' section;
# prop_V[-1][0]: the soma
ax.plot(t[idx], prop_V[0][0][idx], 'r', linewidth=2, label='Dendrite')
ax.plot(t[idx], prop_V[-1][0][idx], 'k', linewidth=1, label='Soma')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Vm (mV)')
ax.legend(loc='upper right');

In [None]:
# total number of recorded locations along the spine-to-soma path
n_rec = len([0 for grp in prop_rec for v in grp])
# base color of each group of traces (one group per section on the path, soma last);
# indexed modulo its length so that paths crossing more than 7 sections do not
# raise an IndexError
colors = [[0,0,0], [1,0,0], [0,1,0], [0,0,1], [0,1,1], [1,0,1], [1,1,0]]
fig,ax = plt.subplots(5, 1, figsize=(5,7), sharex=True)
for i in range(len(prop_V)):
    base_col = np.asarray(colors[i % len(colors)], dtype=float)
    lw = 2 if i == len(prop_V) - 1 else 1    # the soma group is drawn thicker
    for j in range(len(prop_V[i])):
        # successive segments of the same section get lighter shades (capped at white)
        col = np.minimum(base_col + j * 0.2, 1.)
        ax[0].plot(t[idx], prop_V[i][j][idx], color=col, linewidth=lw)
        ax[1].plot(t[idx], prop_Cai[i][j][idx]*1e6, color=col, linewidth=lw)
        ax[2].plot(t[idx], prop_ICa[i][j][idx], color=col, linewidth=lw)
        ax[3].plot(t[idx], prop_INa[i][j][idx], color=col, linewidth=lw)
        ax[4].plot(t[idx], prop_IK[i][j][idx], color=col, linewidth=lw)
for a in ax:
    for side in ('right','top'):
        a.spines[side].set_visible(False)
    # apply to every panel: plt.tick_params would only affect the current (last) axes
    a.tick_params(labelsize=8)
ax[-1].set_xlabel('Time (ms)')
ax[0].set_ylabel('Vm (mV)')
ax[4].set_ylabel('I_K (mA/cm2)')
ax[3].set_ylabel('I_Na (mA/cm2)')
ax[2].set_ylabel('I_Ca (mA/cm2)')
ax[1].set_ylabel('Ca_i (nM)')
fig.tight_layout()
pdf_file = 'burst_exp_decaying_na'
# never overwrite a previously saved figure
if not os.path.isfile(pdf_file + '_1.pdf'):
    fig.savefig(pdf_file + '_1.pdf', bbox_inches='tight')

In [None]:
# Same panels as the previous figure, but ionic currents are converted from densities
# to absolute currents: mA/cm2 * um2 * 1e-2 = nA (1 um2 = 1e-8 cm2, 1 mA = 1e6 nA).
fig,ax = plt.subplots(5, 1, figsize=(5,7), sharex=True)
for i in range(len(prop_V)):
    # colors are cycled so that long spine-to-soma paths do not raise an IndexError
    base_col = np.asarray(colors[i % len(colors)], dtype=float)
    lw = 2 if i == len(prop_V) - 1 else 1    # the soma group is drawn thicker
    for j in range(len(prop_V[i])):
        # successive segments of the same section get lighter shades (capped at white)
        col = np.minimum(base_col + j * 0.2, 1.)
        ax[0].plot(t[idx], prop_V[i][j][idx], color=col, linewidth=lw)
        ax[1].plot(t[idx], prop_Cai[i][j][idx]*1e6, color=col, linewidth=lw)
        ax[2].plot(t[idx], prop_ICa[i][j][idx] * seg_area[i][j] * 1e-2, color=col, linewidth=lw)
        ax[3].plot(t[idx], prop_INa[i][j][idx] * seg_area[i][j] * 1e-2, color=col, linewidth=lw)
        ax[4].plot(t[idx], prop_IK[i][j][idx] * seg_area[i][j] * 1e-2, color=col, linewidth=lw)
for a in ax:
    for side in ('right','top'):
        a.spines[side].set_visible(False)
    # apply to every panel: plt.tick_params would only affect the current (last) axes
    a.tick_params(labelsize=8)
ax[-1].set_xlabel('Time (ms)')
ax[0].set_ylabel('Vm (mV)')
ax[4].set_ylabel('I_K (nA)')
ax[3].set_ylabel('I_Na (nA)')
ax[2].set_ylabel('I_Ca (nA)')
ax[1].set_ylabel('Ca_i (nM)')
fig.tight_layout()
# never overwrite a previously saved figure
if not os.path.isfile(pdf_file + '_2.pdf'):
    fig.savefig(pdf_file + '_2.pdf', bbox_inches='tight')

#### Save the data