In [None]:
import json
import os
import pickle
from itertools import chain

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from neuron import h
from dlutils import utils
from dlutils.cell import Cell, branch_order
from dlutils.synapse import AMPANMDAExp2Synapse
from dlutils.spine import Spine

### Functions used to describe the removal of the Mg block from the NMDA synapse

#### Maex & De Schutter
Maex, R., & De Schutter, E. (1998). Synchronization of Golgi and granule cell firing in a detailed network model of the cerebellar granule cell layer. Journal of Neurophysiology, 80(5), 2521–2537. http://doi.org/10.1152/jn.1998.80.5.2521

#### Jahr & Stevens
Jahr, C. E., & Stevens, C. F. (1990). A quantitative description of NMDA receptor-channel kinetic behavior. The Journal of Neuroscience, 10(6), 1830–1837.

Jahr, C. E., & Stevens, C. F. (1990). Voltage dependence of NMDA-activated macroscopic conductances predicted by single-channel kinetics. The Journal of Neuroscience, 10(9), 3178–3182. http://doi.org/10.1523/JNEUROSCI.10-09-03178.1990

#### Harnett
Harnett, M. T., Makara, J. K., Spruston, N., Kath, W. L., & Magee, J. C. (2012). Synaptic amplification by dendritic spines enhances input cooperativity. Nature, 491(7425), 599–602. http://doi.org/10.1038/nature11554

The Mg unblock function as printed in the Harnett et al. (2012) paper contains a mistake: the units of that expression are incorrect. The correct function is the one implemented below (`mg_harnett`), which is also equation 5 of the second Jahr & Stevens (1990) paper.

In [None]:
# Maex & De Schutter: original parameters.
# extMgConc is shared by every model below; with Kd in mM, extMgConc / Kd is
# dimensionless, so extMgConc is in mM as well.
extMgConc = 1
alpha_vspom = -0.062
v0_block = 10
eta = 0.2801


def mg_maex_orig(v):
    """Mg unblock (Maex & De Schutter form) with the original parameters."""
    return 1. / (1. + eta * extMgConc * np.exp(alpha_vspom * (v - v0_block)))


# Maex & De Schutter: modified parameters (slope doubled, v0_block shifted
# by -20 mV, eta scaled by 0.1).
extMgConc2 = extMgConc
alpha_vspom2 = alpha_vspom * 2
v0_block2 = v0_block - 20
eta2 = eta * 0.1


def mg_maex_mod(v):
    """Mg unblock (Maex & De Schutter form) with the modified parameters."""
    return 1. / (1. + eta2 * extMgConc2 * np.exp(alpha_vspom2 * (v - v0_block2)))


# Voltage axis used for the comparison plot [mV].
v = np.arange(-100, 50)


def mg_harnett(v):
    """Mg unblock as used by Harnett et al. (2012), with corrected units."""
    return 1. / (1. + (extMgConc / 3.57) * np.exp(-0.062 * v))


def mg_jahr_stevens_orig(v):
    """Mg unblock (Jahr & Stevens form) with the published constants."""
    return 1. / (1 + (extMgConc / 9.888) * np.exp(0.09137 * (2.222 - v)))


# Jahr & Stevens: modified parameters (slope scaled by 1.2, shifted by -15 mV).
Kd = 9.888               # (mM)
gamma = 0.09137 * 1.2    # (mV^-1)
sh = -15 + 2.222         # (mV)


def mg_jahr_stevens_mod(v):
    """Mg unblock (Jahr & Stevens form) with the modified gamma and shift."""
    return 1. / (1 + (extMgConc / Kd) * np.exp(gamma * (sh - v)))

fig, ax = plt.subplots(figsize=(8, 4))
# Harnett is the reference curve (thick dashed black line).
ax.plot(v, mg_harnett(v), 'k--', lw=3, label='Harnett')
# Original parameter sets in dark colors, modified ones in light colors.
for unblock_fun, rgb, curve_name in (
        (mg_maex_orig, [0, .5, 0], 'Original Maex & De Schutter'),
        (mg_maex_mod, [.5, 1, .5], 'Modified Maex & De Schutter'),
        (mg_jahr_stevens_orig, [.65, 0, 0], 'Original Jahr & Stevens'),
        (mg_jahr_stevens_mod, [1, .5, .5], 'Modified Jahr & Stevens')):
    ax.plot(v, unblock_fun(v), color=rgb, label=curve_name)
ax.set_xlabel('Voltage (mV)')
ax.set_ylabel('Mg unblock')
ax.legend(loc='best');

### General parameters

In [None]:
cell_type = 'thorny'
# Root folder of the optimization results. Set the OPTIMIZATION_FOLDER environment
# variable to run the notebook on another machine instead of editing this path.
optimization_folder = os.environ.get(
    'OPTIMIZATION_FOLDER',
    '/Users/daniele/Postdoc/Research/Janelia/01_model_optimization/')

if cell_type == 'thorny':
    base_folder = os.path.join(optimization_folder, 'Thorny/DH070813/20191208071008_DH070813_/')
    swc_file = 'DH070813-.Edit.scaled.converted.swc'
    cell_name = 'DH070813_'
    individual = 1
else:
    base_folder = os.path.join(optimization_folder, 'A-thorny/DH070213C3/20191206232623_DH070213C3_/')
    swc_file = 'DH070213C3-.Edit.scaled.converted.swc'
    cell_name = 'DH070213C3_'
    individual = 0

swc_file = os.path.join(base_folder, swc_file)
params_file = os.path.join(base_folder, 'individual_{}.json'.format(individual))
config_file = os.path.join(base_folder, 'parameters.json')

passive = True
with_TTX = False

# Context managers close the files once they have been read.
with open(params_file, 'r') as fid:
    parameters = json.load(fid)
mechanisms = utils.extract_mechanisms(config_file, cell_name)
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open(os.path.join(base_folder, 'simulation_parameters.pkl'), 'rb') as fid:
    sim_pars = pickle.load(fid)
# Axon handling is taken from the saved simulation parameters.
replace_axon = sim_pars['replace_axon']
add_axon_if_missing = not sim_pars['no_add_axon']

### Instantiate the cell

In [None]:
# The random numeric suffix presumably keeps cell names distinct when this cell
# is re-run in the same kernel.
cell = Cell('CA3_cell_%d' % int(np.random.uniform() * 1e5), swc_file, parameters, mechanisms)
cell.instantiate(replace_axon, add_axon_if_missing, force_passive=passive, TTX=with_TTX)

# Apical branch that will receive the spines, chosen per cell type.
section_num = 14 if cell_type == 'thorny' else 16
section = cell.morpho.apic[section_num]
# Axial resistance later passed to the spines (scaled by 1.5 for thorny cells).
Ra = section.Ra * 1.5 if cell_type == 'thorny' else section.Ra

print('Branch order of section {}: {}.'.format(section.name(), branch_order(section)))

### Instantiate the spines

In [None]:
# In the Harnett paper the head is spherical with a diameter of 0.5 um: a cylinder
# with diameter and length equal to 0.5 um has the same (outer) surface area.
head_L = 0.5         # [um]
head_diam = 0.5      # [um]
neck_L = 1.58        # [um]
neck_diam = 0.05 if cell_type == 'thorny' else 0.077   # [um]
spine_distance = 5   # [um] distance between neighboring spines
n_spines = 9
L = spine_distance * (n_spines - 1)   # [um] length spanned by the spine cluster
norm_L = L / section.L                # same length in normalized section coordinates

# The cluster is centered at x = 0.5 on section 0 and at x = 0.2 otherwise.
center = 0.5 if section_num == 0 else 0.2
spine_locs = np.linspace(center - norm_L / 2, center + norm_L / 2, n_spines)
# A cluster longer than the room around `center` would give positions outside [0, 1].
if spine_locs[0] < 0 or spine_locs[-1] > 1:
    raise ValueError('A cluster of {} spines spanning {} um does not fit on {} '
                     '(L = {:.1f} um) around x = {}.'.format(n_spines, L, section.name(),
                                                             section.L, center))

spines = [Spine(section, x, head_L, head_diam, neck_L, neck_diam, Ra, i)
          for i, x in enumerate(spine_locs)]

for spine in spines:
    spine.instantiate()

#### Check which dendritic segments the spines are attached to

In [None]:
# Group consecutive spines attached to the same segment of the dendritic section:
# segments[g] is the g-th distinct segment, segments_idx[g] the spine indices on it.
segments = [section(spines[0]._sec_x)]
segments_idx = [[0]]
for k, spine in enumerate(spines[1:], start=1):
    seg = section(spine._sec_x)
    if seg == segments[-1]:
        segments_idx[-1].append(k)
        continue
    segments.append(seg)
    segments_idx.append([k])

n_groups = len(segments_idx)
if n_groups == 1:
    print('All spines are connected to the same segment.')
elif n_groups == n_spines:
    print('Each spine is connected to a different segment on the dendritic branch.')
else:
    for group in segments_idx:
        if len(group) == 1:
            print('Spine {} is connected to a distinct segment.'.format(group[0]))
        else:
            print('Spines {} are connected to the same segment.'.format(group))

#### Show where the spines are located on the dendritic tree

In [None]:
plt.figure(figsize=(10, 10))
# Apical sections in black, basal sections in blue, each labeled (in magenta)
# with its index, taken from the part of its name between the brackets.
for sec in chain(cell.morpho.apic, cell.morpho.basal):
    color = 'k' if sec in cell.morpho.apic else 'b'
    lbl = sec.name().split('.')[1].split('[')[1][:-1]
    # (x, y) of the 3D points of the section, shape (n3d, 2)
    sec_coords = np.array([[sec.x3d(p), sec.y3d(p)] for p in range(sec.n3d())],
                          dtype=float).reshape(-1, 2)
    mid_point = len(sec_coords) // 2
    plt.text(sec_coords[mid_point, 0], sec_coords[mid_point, 1], lbl,
             fontsize=14, color='m')
    plt.plot(sec_coords[:, 0], sec_coords[:, 1], color, lw=1)
# Spine points in red on top of the dendritic tree.
for spine in spines:
    plt.plot(spine._points[:, 0], spine._points[:, 1], 'r.')
plt.axis('equal');

### Insert a synapse into each spine

In [None]:
# Integer codes for the Mg unblock model, written to the NMDA mechanism.
MG_MODELS = {'MDS': 1, 'HRN': 2, 'JS': 3}
Mg_unblock_model = 'JS'

E = 0        # [mV] presumably the synaptic reversal potential

# Time constants in ms; weights are [AMPA, NMDA].
AMPA_taus = {'tau1': 0.1, 'tau2': 1.0}
if cell_type == 'thorny':
    weights = np.array([0.0001, 0.0006])
    NMDA_taus = {'tau1': 1.0, 'tau2': 50.0}
else:
    weights = np.array([0.0004, 0.0008])
    NMDA_taus = {'tau1': 1.0, 'tau2': 100.0}

for receptor, taus in (('AMPA', AMPA_taus), ('NMDA', NMDA_taus)):
    print(receptor + ':')
    print('    tau_rise = {:.3f} ms'.format(taus['tau1']))
    print('   tau_decay = {:.3f} ms'.format(taus['tau2']))

# One synapse on each spine head.
synapses = [AMPANMDAExp2Synapse(spine.head, 1, E, weights, AMPA=AMPA_taus, NMDA=NMDA_taus)
            for spine in spines]

# Modified Mg unblock parameters (from the Mg cell above) for the chosen model;
# the Harnett model keeps its defaults.
if Mg_unblock_model == 'MDS':
    mg_overrides = {'alpha_vspom': alpha_vspom2, 'v0_block': v0_block2, 'eta': eta2}
elif Mg_unblock_model == 'JS':
    mg_overrides = {'gamma': gamma, 'sh': sh}
else:
    mg_overrides = {}

for syn in synapses:
    syn.nmda_syn.mg_unblock_model = MG_MODELS[Mg_unblock_model]
    for param_name, param_value in mg_overrides.items():
        setattr(syn.nmda_syn, param_name, param_value)

# Report the values actually stored in the first NMDA synapse.
if Mg_unblock_model == 'MDS':
    nmda_mech = synapses[0].nmda_syn
    print('\nUsing Maex & De Schutter Mg unblock model. Modified parameters:')
    print('       alpha = {:.3f} 1/mV'.format(nmda_mech.alpha_vspom))
    print('    v0_block = {:.3f} mV'.format(nmda_mech.v0_block))
    print('         eta = {:.3f}'.format(nmda_mech.eta))
elif Mg_unblock_model == 'JS':
    nmda_mech = synapses[0].nmda_syn
    print('\nUsing Jahr & Stevens Mg unblock model. Modified parameters:')
    print('       gamma = {:.3f} 1/mV'.format(nmda_mech.gamma))
    print('          sh = {:.3f} mV'.format(nmda_mech.sh))
elif Mg_unblock_model == 'HRN':
    print('\nUsing Harnett Mg unblock model with default parameters.')

#### Compute the presynaptic spike times

In [None]:
t0 = 1000.   # [ms] time of the first presynaptic event
dt = 1000.   # [ms] interval between successive events
# Spine k fires at every event from the k-th onwards (times t0 + k*dt, ..., t0 + (n_spines-1)*dt),
# so event m (m = 0, 1, ...) activates spines 0..m, i.e. m + 1 synaptic inputs.
spike_times = [np.sort(t0 + (n_spines - 1) * dt - np.arange(i) * dt) for i in range(n_spines, 0, -1)]
poisson = False
sequential = True
if sequential:
    # Fixed 0.3 ms delay between consecutive spines within each event.
    for k in range(n_spines):
        spike_times[k] += k * 0.3
elif poisson:
    # Within each event, inputs arrive as a Poisson process. All times are in ms,
    # like t0 and dt.
    mean_ISI = 0.3   # [ms] mean interval between successive inputs of one event
    for m in range(n_spines):
        # Arrival times of the m + 1 inputs of event m, relative to the event onset.
        arrivals = np.cumsum(np.random.exponential(mean_ISI, size=m + 1))
        for k in range(m + 1):
            # Spine k's spike in event m is entry m - k of its spike train.
            spike_times[k][m - k] += arrivals[k]

for syn, spks in zip(synapses, spike_times):
    syn.set_presynaptic_spike_times(spks)

In [None]:
# Display the presynaptic spike times: one array per spine (spine k fires at events k, k+1, ...).
spike_times

#### Make the recorders

In [None]:
# Time and somatic voltage.
rec = {'t': h.Vector(), 'Vsoma': h.Vector()}
rec['t'].record(h._ref_t)
rec['Vsoma'].record(cell.morpho.soma[0](0.5)._ref_v)

# Per-spine recordings, stored under keys '<name>-<spine index>'.
for i, spine in enumerate(spines):
    ampa_pp = synapses[i].syn[0]
    nmda_pp = synapses[i].syn[1]
    sources = (
        ('Vdend', spine._sec(spine._sec_x)._ref_v),   # dendrite at the spine attachment point
        ('Vspine', spine.head(0.5)._ref_v),
        ('IAMPA', ampa_pp._ref_i),
        ('gAMPA', ampa_pp._ref_g),
        ('INMDA', nmda_pp._ref_i),
        ('gNMDA', nmda_pp._ref_g),
        ('MgBlock', nmda_pp._ref_mgBlock),
    )
    for quantity, pointer in sources:
        key = '{}-{}'.format(quantity, i)
        rec[key] = h.Vector()
        rec[key].record(pointer)

#### Run the simulation

In [None]:
# Stop one inter-event interval (dt) after the last event at t0 + (n_spines - 1) * dt.
h.tstop = t0 + n_spines * dt
h.cvode_active(1)   # variable time-step integration
h.run()

#### Get the data from the recorders

In [None]:
def _stack_spine_traces(quantity, scale=None):
    """Return the recordings '<quantity>-<k>' for k = 0..n_spines-1 as one 2D array (row k = spine k).

    When `scale` is given, each trace is multiplied by it.
    """
    traces = [np.array(rec['{}-{}'.format(quantity, k)]) for k in range(n_spines)]
    if scale is not None:
        traces = [trace * scale for trace in traces]
    return np.array(traces)


t = np.array(rec['t'])
# Currents and conductances are scaled by 1e3 (presumably nA -> pA and uS -> nS;
# the conductance plots below are labeled nS).
iampa = _stack_spine_traces('IAMPA', 1e3)
inmda = _stack_spine_traces('INMDA', 1e3)
gampa = _stack_spine_traces('gAMPA', 1e3)
gnmda = _stack_spine_traces('gNMDA', 1e3)
MgBlock = _stack_spine_traces('MgBlock')
Vspine = _stack_spine_traces('Vspine')
Vdend = _stack_spine_traces('Vdend')

#### Measure the spine-to-dendrite EPSP amplitude ratio (AR) for a single synaptic input

In [None]:
# Window around the first event, when only spine 0 is active: from 10 ms before
# to 200 ms after t0, so the analysis follows t0 if it is changed.
idx, = np.where((t > t0 - 10) & (t < t0 + 200))
# The last recorded sample (dt ms after the last event) is used as the baseline.
EPSP_spine = np.max(Vspine[0, idx]) - Vspine[0, -1]
EPSP_dend = np.max(Vdend[0, idx]) - Vdend[0, -1]
AR = EPSP_spine / EPSP_dend   # spine head / parent dendrite EPSP amplitude
print('AR = {:.3f}'.format(AR))

fig, ax = plt.subplots()
ax.plot(t[idx], Vspine[0, idx], 'k', label='Spine')
ax.plot(t[idx], Vdend[0, idx], 'r', label='Dendrite')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Vm (mV)')
ax.legend(loc='best');

#### Measure the dendritic EPSPs and the NMDA conductance values

In [None]:
# For each event (spine 0 fires once per event), take the peak dendritic voltage at
# spine 0 and the peak effective NMDA conductance (g * Mg unblock) of spine 0 within
# the following 200 ms.
V_pks = np.zeros(n_spines)
dG = np.zeros(n_spines)
for event, spk in enumerate(spike_times[0]):
    in_window = (t > spk) & (t < spk + 200)
    V_pks[event] = Vdend[0, in_window].max()
    dG[event] = (gnmda[0, in_window] * MgBlock[0, in_window]).max()
# EPSP amplitudes relative to the last recorded dendritic voltage.
dV = V_pks - Vdend[0, -1]

#### Plot the time course of the dendritic EPSPs and the NMDA conductance for increasing number of synaptic inputs

In [None]:
window = [50, 100]   # [ms] shown before / after each event
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), sharex=True)
# Overlay all events, aligned on spine 0's spike time.
for spk in spike_times[0]:
    sel = (t > spk - window[0]) & (t < spk + window[1])
    t_rel = t[sel] - spk
    ax1.plot(t_rel, Vdend[0, sel], 'k')
    ax2.plot(t_rel, gnmda[0, sel] * MgBlock[0, sel], 'k')
for panel in (ax1, ax2):
    panel.set_xlabel('Time (ms)')
ax1.set_ylabel('Dendritic voltage (mV)')
ax2.set_ylabel('NMDA conductance (nS)');

#### Plot the time course of the EPSP at each spine for increasing number of synaptic inputs

In [None]:
cmap = plt.get_cmap('rainbow', n_spines)
window = [50, 100]   # [ms] shown before / after each event
# One panel per spine on a near-square grid (3x3 for 9 spines); unused panels are hidden.
n_cols = int(np.ceil(np.sqrt(n_spines)))
n_rows = int(np.ceil(n_spines / n_cols))
fig, ax = plt.subplots(n_rows, n_cols, figsize=(14, 10), sharex=True, squeeze=False)
panels = ax.ravel()
# Color encodes the event, i.e. the number of active inputs.
for n, spk in enumerate(spike_times[0]):
    idx = (t > spk - window[0]) & (t < spk + window[1])
    lbl = '{} input'.format(n + 1) if n == 0 else '{} inputs'.format(n + 1)
    for k in range(n_spines):
        panels[k].plot(t[idx] - spk, Vspine[k, idx], color=cmap(n), label=lbl)
# Titles and axis labels are set once per panel, outside the plotting loop.
for k, panel in enumerate(panels):
    if k >= n_spines:
        panel.set_visible(False)
        continue
    panel.set_title('Spine #{}'.format(k + 1))
    if k % n_cols == 0:
        panel.set_ylabel('Spine head voltage (mV)')
    if k + n_cols >= n_spines:   # no visible panel below this one
        panel.set_xlabel('Time (ms)')
ax[0, 0].legend(loc='best');

#### Plot the time course of the NMDA conductance in each spine for increasing number of synaptic inputs

In [None]:
cmap = plt.get_cmap('rainbow', n_spines)
window = [50, 100]   # [ms] shown before / after each event
# One panel per spine on a near-square grid (3x3 for 9 spines); unused panels are hidden.
n_cols = int(np.ceil(np.sqrt(n_spines)))
n_rows = int(np.ceil(n_spines / n_cols))
fig, ax = plt.subplots(n_rows, n_cols, figsize=(14, 10), sharex=True, squeeze=False)
panels = ax.ravel()
# Effective NMDA conductance (g * Mg unblock); color encodes the number of active inputs.
for n, spk in enumerate(spike_times[0]):
    idx = (t > spk - window[0]) & (t < spk + window[1])
    lbl = '{} input'.format(n + 1) if n == 0 else '{} inputs'.format(n + 1)
    for k in range(n_spines):
        panels[k].plot(t[idx] - spk, gnmda[k, idx] * MgBlock[k, idx], color=cmap(n), label=lbl)
# Titles and axis labels are set once per panel, outside the plotting loop.
for k, panel in enumerate(panels):
    if k >= n_spines:
        panel.set_visible(False)
        continue
    panel.set_title('Spine #{}'.format(k + 1))
    if k % n_cols == 0:
        panel.set_ylabel('NMDA conductance (nS)')
    if k + n_cols >= n_spines:   # no visible panel below this one
        panel.set_xlabel('Time (ms)')
ax[0, 0].legend(loc='best');

#### Plot the amplitude of the dendritic EPSPs and of the NMDA conductance as a function of the number of synaptic inputs

In [None]:
n = 1 + np.arange(n_spines)   # number of active inputs at each event
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), sharex=True)
# Linear prediction: n times the single-input EPSP.
ax1.plot(n, n * dV[0], 'rv-', lw=1, markerfacecolor='w', markersize=8, label='Linear prediction')
ax1.plot(n, dV, 'ko-', lw=1, markerfacecolor='w', markersize=7, label='Measured')
ax2.plot(n, dG, 'ko-', lw=1, markerfacecolor='w', markersize=7, label='Measured')
ax1.set_xlabel('Input number')
ax1.set_ylabel('Dendrite EPSP (mV)')
ax2.set_xlabel('Input number')
ax2.set_ylabel('NMDA conductance (nS)')
ax1.legend(loc='best')
# Limits span the full data range so that non-monotonic curves are not clipped.
ax1.set_ylim([0, np.max([dV.max(), dV[0] * n_spines]) * 1.1])
ax2.set_ylim([dG.min() - 0.1, dG.max() + 0.1]);