In [None]:
import os
import json
import pickle
from itertools import chain

import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
from scipy.signal import find_peaks
import matplotlib
import matplotlib.pyplot as plt
from plotly.colors import qualitative
%matplotlib inline

from neuron import h
from dlutils import utils
from dlutils.cell import Cell, branch_order
from dlutils.synapse import AMPANMDAExp2Synapse
from dlutils.spine import Spine

#### A couple of functions useful for recording arbitrary ion currents in a segment

In [None]:
def make_mechanism_recorders(segment, suffix, gbar, Erev, gating_variables, calcium_current=False, with_h2=False):
    """Create the NEURON recorders needed to rebuild the current of one mechanism in `segment`.

    Parameters
    ----------
    segment : NEURON segment in which the mechanism is inserted.
    suffix : mechanism suffix (e.g. 'nax'); range variables are read as `<name>_<suffix>`.
    gbar : name (without suffix) of the maximal-conductance range variable.
    Erev : reversal potential, either a number or the name of a segment attribute (e.g. 'ek').
    gating_variables : dict mapping each gating-variable name (without suffix) to the
        exponent it carries in the conductance expression.
    calcium_current : if True, also record cai and cao (needed by the GHK current).
    with_h2 : stored as-is; compute_current uses it to apply a calcium-dependent factor.

    Returns
    -------
    dict with keys 'Erev', 'gbar', 'vars', 'calcium_current', 'with_h2'. 'vars' maps each
    recorded quantity to {'vec': h.Vector} (plus 'expon' for gating variables).
    """
    from numbers import Number

    def _recording(ref, **extra):
        # fresh vector recording the variable pointed to by `ref`, bundled with `extra` fields
        vec = h.Vector()
        vec.record(ref)
        return dict(vec=vec, **extra)

    reversal = Erev if isinstance(Erev, Number) else getattr(segment, Erev)
    max_conductance = getattr(segment, gbar + '_' + suffix)
    recorded = {'voltage': _recording(segment._ref_v)}
    if calcium_current:
        for ion_conc in ('cai', 'cao'):
            recorded[ion_conc] = _recording(getattr(segment, '_ref_' + ion_conc))
    for var_name, exponent in gating_variables.items():
        recorded[var_name] = _recording(getattr(segment, '_ref_' + var_name + '_' + suffix), expon=exponent)
    return {
        'Erev': reversal,
        'gbar': max_conductance,
        'vars': recorded,
        'calcium_current': calcium_current,
        'with_h2': with_h2,
    }


def compute_current(mech_rec):
    """Rebuild a mechanism's current from the traces recorded by make_mechanism_recorders.

    The conductance is gbar times the product of every recorded gating variable raised to
    its exponent. For non-calcium mechanisms the current is g * (v - Erev). For calcium
    mechanisms it is g * GHK(v, cai, cao), optionally multiplied by 1e-3 / (1e-3 + cai)
    when 'with_h2' is set. Units follow the mechanism (S/cm2 x mV = mA/cm2 if gbar is in S/cm2).

    Parameters
    ----------
    mech_rec : dict as returned by make_mechanism_recorders, after the simulation has run.

    Returns
    -------
    numpy array with one current value per recorded time sample.
    """
    celsius = h.celsius   # simulation temperature, read at call time

    def efun(z):
        # z / (exp(z) - 1), replaced by its first-order expansion 1 - z/2 close to z = 0
        out = np.zeros(z.shape)
        near_zero = np.abs(z) < 1e-4
        far = ~near_zero
        out[near_zero] = 1 - z[near_zero] / 2
        out[far] = z[far] / (np.exp(z[far]) - 1)
        return out

    def ghk(v, cai, cao):
        # Goldman-Hodgkin-Katz driving force for a divalent ion, in mV
        f = ((25.0 / 293.15) * (celsius + 273.15)) / 2
        nu = v / f
        return -f * (1. - (cai / cao) * np.exp(nu)) * efun(nu)

    traces = mech_rec['vars']
    v = np.array(traces['voltage']['vec'])
    g = mech_rec['gbar']
    for var_name, trace in traces.items():
        if var_name in ('voltage', 'cai', 'cao'):
            continue
        g *= np.array(trace['vec']) ** trace['expon']
    if not mech_rec['calcium_current']:
        return g * (v - mech_rec['Erev'])
    cai = np.array(traces['cai']['vec'])
    cao = np.array(traces['cao']['vec'])
    if mech_rec['with_h2']:
        g *= (1e-3 / (1e-3 + cai))
    return g * ghk(v, cai, cao)

### General parameters

In [None]:
# Notebook configuration for the cell type to simulate
# (alternative: 'synaptic_inputs_active_thorny.json').
config_file = 'synaptic_inputs_active_a-thorny.json'
with open(config_file, 'r') as fid:
    config = json.load(fid)

optimization_folder = config['optimization_folder']
cell_type = config['cell_type']
prefix = cell_type[0].upper() + cell_type[1:]   # e.g. 'thorny' -> 'Thorny'
base_folder = optimization_folder + prefix + '/' + config['cell_name'] + '/' + config['optimization_run'] + '/'
swc_file = config['swc_file']
cell_name = config['cell_name'] + '_'
individual = config['individual']

swc_file = base_folder + swc_file
params_file = base_folder + 'individual_{}.json'.format(individual)
# NOTE: from here on `config_file` is the optimization's parameters file,
# not the notebook configuration loaded above.
config_file = base_folder + 'parameters.json'

passive = False
with_TTX = False
with open(params_file, 'r') as fid:
    parameters = json.load(fid)
mechanisms = utils.extract_mechanisms(config_file, cell_name)
# NOTE(review): pickle.load can execute arbitrary code; only load files you produced yourself.
with open(base_folder + 'simulation_parameters.pkl', 'rb') as fid:
    sim_pars = pickle.load(fid)
# axon handling is taken from the saved simulation parameters
# (presumably those used during the optimization)
replace_axon = sim_pars['replace_axon']
add_axon_if_missing = not sim_pars['no_add_axon']

### Functions used to describe the removal of the Mg block from the NMDA synapse

#### Maex & De Schutter
Maex, R., & De Schutter, E. (1998). Synchronization of Golgi and granule cell firing in a detailed network model of the cerebellar granule cell layer. Journal of Neurophysiology, 80(5), 2521–2537. http://doi.org/10.1152/jn.1998.80.5.2521

#### Jahr & Stevens
Jahr, C. E., & Stevens, C. F. (1990). A quantitative description of NMDA receptor-channel kinetic behavior. The Journal of Neuroscience, 10(6), 1830–1837.

Jahr, C. E., & Stevens, C. F. (1990). Voltage dependence of NMDA-activated macroscopic conductances predicted by single-channel kinetics. The Journal of Neuroscience, 10(9), 3178–3182. http://doi.org/10.1523/JNEUROSCI.10-09-03178.1990

#### Harnett
Harnett, M. T., Makara, J. K., Spruston, N., Kath, W. L., & Magee, J. C. (2012). Synaptic amplification by dendritic spines enhances input cooperativity. Nature, 491(7425), 599–602. http://doi.org/10.1038/nature11554

The Mg unblock function as printed in the Harnett et al. paper cited above is incorrect: the units in its expression are wrong. The correct function is the one implemented here (`mg_harnett`), which also appears as equation 5 in the second of the two Jahr & Stevens papers.

In [None]:
# Compare the Mg-unblock curves of the different models over a range of membrane potentials.
v = np.arange(-100, 50)                    # [mV]
extMgConc = config['NMDA']['extMgConc']    # [mM] external Mg concentration

def mg_maex_orig(v):
    """Original Maex & De Schutter unblock function."""
    return 1. / (1. + 0.2801 * extMgConc * np.exp(-0.062 * (v - 10)))

def mg_jahr_stevens_orig(v):
    """Original Jahr & Stevens unblock function."""
    return 1. / (1 + (extMgConc / 9.888) * np.exp(0.09137 * (2.222 - v)))

def mg_harnett(v):
    """Unblock function used by Harnett et al."""
    return 1. / (1. + (extMgConc / 3.57) * np.exp(-0.062 * v))

fig, ax = plt.subplots(figsize=(8,4))
ax.plot(v, mg_harnett(v), 'k--', lw=3, label='Harnett')
ax.plot(v, mg_maex_orig(v), color=[0,.5,0], label='Original Maex & De Schutter')

nmda_model = config['NMDA']['model']
if nmda_model == 'MDS':
    alpha_vspom = config['NMDA']['alpha_vspom'] # -0.124
    v0_block    = config['NMDA']['v0_block']    # -10
    eta         = config['NMDA']['eta']         # 0.02801

    def mg_maex_mod(v):
        """Maex & De Schutter unblock function with the parameters from the config."""
        return 1. / (1. + eta * extMgConc * np.exp(alpha_vspom * (v - v0_block)))

    ax.plot(v, mg_maex_mod(v), color=[.5,1,.5], label='Modified Maex & De Schutter')

ax.plot(v, mg_jahr_stevens_orig(v), color=[.65,0,0], label='Original Jahr & Stevens')

if nmda_model == 'JS':
    Kd    = config['NMDA']['Kd']    # (mM)
    gamma = config['NMDA']['gamma'] # (mV^-1)
    sh    = config['NMDA']['sh']    # (mV)

    def mg_jahr_stevens_mod(v):
        """Jahr & Stevens unblock function with the parameters from the config."""
        return 1. / (1 + (extMgConc / Kd) * np.exp(gamma * (sh - v)))

    ax.plot(v, mg_jahr_stevens_mod(v), color=[1,.5,.5], label='Modified Jahr & Stevens')

ax.set_xlabel('Voltage (mV)')
ax.set_ylabel('Mg unblock')
ax.legend(loc='best');

### Instantiate the cell

In [None]:
# Random numeric suffix so that every instance of the cell gets a (probably) unique name.
cell_id = int(np.random.uniform()*1e5)
cell = Cell(f'CA3_cell_{cell_id}', swc_file, parameters, mechanisms)
cell.instantiate(replace_axon, add_axon_if_missing, force_passive=passive, TTX=with_TTX)

section_num = config['section_num']
section = cell.morpho.apic[section_num]   # apical section that will host the spines
# axial resistivity of the spine necks, as a multiple of the parent section's Ra
Ra = section.Ra * config['Ra_neck_coeff']
print(f'Branch order of section {section.name()}: {branch_order(section)}.')

### Instantiate the spines

In [None]:
# In the Harnett paper the spine head is spherical with a diameter of 0.5 um: a cylinder
# with diameter and length both equal to 0.5 um has the same lateral surface area.
head_L = config['spine']['head_L']           # [um]
head_diam = config['spine']['head_diam']     # [um]
neck_L = config['spine']['neck_L']           # [um]
neck_diam = config['spine']['neck_diam']     # [um]
spine_distance = config['spine_distance']    # [um] distance between neighboring spines
n_spines = config['n_spines']                # number of spines
L = spine_distance * (n_spines - 1)          # [um] stretch of dendrite covered by the spines
norm_L = L / section.L                       # the same stretch, as a fraction of the section
if norm_L > 1:
    # the cluster cannot fit: clamping below would place spines outside [0, 1]
    raise ValueError('{} spines spaced {} um apart span {:.2f} um, more than the {:.2f} um of {}.'
                     .format(n_spines, spine_distance, L, section.L, section.name()))

# Center the cluster on spine_loc, shifting it if needed so that it stays within [0, 1].
spine_loc = config['spine_loc']
start, stop = spine_loc + norm_L/2 * np.array([-1,1])
if start < 0:
    start = 0
    stop = start + norm_L
if stop > 1:
    stop = 1
    start = stop - norm_L
spines = [Spine(section, x, head_L, head_diam, neck_L, neck_diam, Ra, i)
          for i, x in enumerate(np.linspace(start, stop, n_spines))]

for spine in spines:
    spine.instantiate()

#### Check the location of the spines in terms of distinct segments

In [None]:
# Group consecutive spines that are attached to the same dendritic segment.
segments, segments_idx = [section(spines[0]._sec_x)], [[0]]
for spine_id in range(1, len(spines)):
    seg = section(spines[spine_id]._sec_x)
    if seg == segments[-1]:
        segments_idx[-1].append(spine_id)
    else:
        segments.append(seg)
        segments_idx.append([spine_id])

n_groups = len(segments_idx)
if n_groups == 1:
    print('All spines are connected to the same segment.')
elif n_groups == n_spines:
    print('Each spine is connected to a different segment on the dendritic branch.')
else:
    for group in segments_idx:
        if len(group) == 1:
            print('Spine {} is connected to a distinct segment.'.format(group[0]))
        else:
            print('Spines {} are connected to the same segment.'.format(group))

#### Show where the spines are located on the dendritic tree

In [None]:
plt.figure(figsize=(10,10))
# Apical sections in black, basal ones in blue; each one is labelled (in magenta)
# with the index extracted from its name.
for sections, color in ((cell.morpho.apic, 'k'), (cell.morpho.basal, 'b')):
    for sec in sections:
        lbl = sec.name().split('.')[1].split('[')[1][:-1]
        n = sec.n3d()
        sec_coords = np.array([[sec.x3d(k), sec.y3d(k)] for k in range(n)], dtype=float).reshape(n, 2)
        middle = n // 2
        plt.text(sec_coords[middle, 0], sec_coords[middle, 1], lbl, fontsize=14, color='m')
        plt.plot(sec_coords[:, 0], sec_coords[:, 1], color, lw=1)
# spine locations, in red
for spine in spines:
    plt.plot(spine._points[:, 0], spine._points[:, 1], 'r.')
plt.axis('equal');

### Insert a synapse into each spine

In [None]:
# Integer codes for the NMDA synapse's `mg_unblock_model` parameter.
MG_MODELS = {'MDS': 1, 'HRN': 2, 'JS': 3}
Mg_unblock_model = config['NMDA']['model']

E = 0        # [mV] synaptic reversal potential passed to AMPANMDAExp2Synapse

AMPA_taus = config['AMPA']['time_constants']
NMDA_taus = config['NMDA']['time_constants']
weights = np.array([config['AMPA']['weight'], config['NMDA']['weight']])

print('AMPA:')
print('    tau_rise = {:.3f} ms'.format(AMPA_taus['tau1']))
print('   tau_decay = {:.3f} ms'.format(AMPA_taus['tau2']))
print('NMDA:')
print('    tau_rise = {:.3f} ms'.format(NMDA_taus['tau1']))
print('   tau_decay = {:.3f} ms'.format(NMDA_taus['tau2']))

# one AMPA+NMDA synapse in the head of each spine
synapses = [AMPANMDAExp2Synapse(spine.head, 1, E, weights, AMPA=AMPA_taus, NMDA=NMDA_taus)
            for spine in spines]

for syn in synapses:
    syn.nmda_syn.mg_unblock_model = MG_MODELS[Mg_unblock_model]
    if Mg_unblock_model == 'MDS':
        syn.nmda_syn.alpha_vspom = config['NMDA']['alpha_vspom']
        syn.nmda_syn.v0_block = config['NMDA']['v0_block']
        syn.nmda_syn.eta = config['NMDA']['eta']
    elif Mg_unblock_model == 'JS':
        syn.nmda_syn.Kd = config['NMDA']['Kd']
        syn.nmda_syn.gamma = config['NMDA']['gamma']
        syn.nmda_syn.sh = config['NMDA']['sh']

if Mg_unblock_model == 'MDS':
    print('\nUsing Maex & De Schutter Mg unblock model. Modified parameters:')
    print('       alpha = {:.3f} 1/mV'.format(synapses[0].nmda_syn.alpha_vspom))
    print('    v0_block = {:.3f} mV'.format(synapses[0].nmda_syn.v0_block))
    print('         eta = {:.3f}'.format(synapses[0].nmda_syn.eta))
elif Mg_unblock_model == 'JS':
    print('\nUsing Jahr & Stevens Mg unblock model. Modified parameters:')
    # Kd is a concentration: it divides extMgConc (mM) in the unblock function
    print('          Kd = {:.3f} mM'.format(synapses[0].nmda_syn.Kd))
    print('       gamma = {:.3f} 1/mV'.format(synapses[0].nmda_syn.gamma))
    print('          sh = {:.3f} mV'.format(synapses[0].nmda_syn.sh))
elif Mg_unblock_model == 'HRN':
    print('\nUsing Harnett Mg unblock model with default parameters.')

#### Make the recorders

In [None]:
def _recording_vector(ref):
    """Return a new h.Vector recording the NEURON variable pointed to by `ref`."""
    vec = h.Vector()
    vec.record(ref)
    return vec

# time, somatic voltage and somatic spike times (APCount threshold at -20 mV)
rec = {
    't': _recording_vector(h._ref_t),
    'Vsoma': _recording_vector(cell.morpho.soma[0](0.5)._ref_v),
    'spike_times': h.Vector(),
}
apc = h.APCount(cell.morpho.soma[0](0.5))
apc.thresh = -20
apc.record(rec['spike_times'])

# voltage in each spine head and in the dendritic segment under each spine
for k, spine in enumerate(spines):
    rec[f'Vdend-{k}'] = _recording_vector(spine._sec(spine._sec_x)._ref_v)
    rec[f'Vspine-{k}'] = _recording_vector(spine.head(0.5)._ref_v)

# Voltage, ionic currents and internal calcium along the path from the first spine's
# segment back to the soma, to follow how signals attenuate towards the cell body.
# Each list gets one sub-list per section on the path (the soma is the last one).
prop_rec, prop_rec_ina, prop_rec_ik, prop_rec_ica, prop_rec_cai, seg_area = [], [], [], [], [], []
sec = section
while sec != cell.morpho.soma[0]:
    for per_section in (prop_rec, prop_rec_ina, prop_rec_ik, prop_rec_ica, prop_rec_cai, seg_area):
        per_section.append([])
    for seg in sec:
        prop_rec[-1].append(_recording_vector(seg._ref_v))
        prop_rec_ina[-1].append(_recording_vector(seg._ref_ina))
        prop_rec_ik[-1].append(_recording_vector(seg._ref_ik))
        prop_rec_ica[-1].append(_recording_vector(seg._ref_ica))
        prop_rec_cai[-1].append(_recording_vector(seg._ref_cai))
        seg_area[-1].append(seg.area())
        # on the spiny section, stop at the segment the first spine is attached to
        if sec == section and seg == section(spines[0]._sec_x):
            break
    sec = sec.parentseg().sec

soma_center = cell.morpho.soma[0](0.5)
prop_rec.append([_recording_vector(soma_center._ref_v)])
prop_rec_ina.append([_recording_vector(soma_center._ref_ina)])
prop_rec_ik.append([_recording_vector(soma_center._ref_ik)])
prop_rec_ica.append([_recording_vector(soma_center._ref_ica)])
prop_rec_cai.append([_recording_vector(soma_center._ref_cai)])
seg_area.append([soma_center.area()])

In [None]:
# Recorders needed to rebuild each ionic current (see compute_current) in the soma
# and in the dendritic segment under the first spine.
soma_seg = cell.morpho.soma[0](0.5)
dend_seg = spines[0]._sec(spines[0]._sec_x)
mech_rec = {
    'soma': {
        'nax':  make_mechanism_recorders(soma_seg, 'nax', 'gbar', 'ena', {'m': 3, 'h': 1}),
        'nap':  make_mechanism_recorders(soma_seg, 'nap', 'gnabar', 'ena', {'n': 3}),
        'kdr':  make_mechanism_recorders(soma_seg, 'kdr', 'gkdrbar', 'ek', {'n': 1}),
        'kap':  make_mechanism_recorders(soma_seg, 'kap', 'gkabar', 'ek', {'n': 1, 'l': 1}),
        'kmb':  make_mechanism_recorders(soma_seg, 'kmb', 'gbar', 'ek', {'m': 1}),
        'kca':  make_mechanism_recorders(soma_seg, 'kca', 'gbar', 'ek', {'m': 3}),
        'cagk': make_mechanism_recorders(soma_seg, 'cagk', 'gbar', 'ek', {'o': 1}),
        'hd':   make_mechanism_recorders(soma_seg, 'hd', 'ghdbar', -30, {'l': 1}),
        'cat':  make_mechanism_recorders(soma_seg, 'cat', 'gcatbar', 'eca', {'m': 2, 'h': 1},
                                         calcium_current=True, with_h2=False),
        'cal':  make_mechanism_recorders(soma_seg, 'cal', 'gcalbar', 'eca', {'m': 2},
                                         calcium_current=True, with_h2=True),
        'can':  make_mechanism_recorders(soma_seg, 'can', 'gcanbar', 'eca', {'m': 2, 'h': 1},
                                         calcium_current=True, with_h2=True),
    },
    'dend': {
        'nax':  make_mechanism_recorders(dend_seg, 'nax', 'gbar', 'ena', {'m': 3, 'h': 1}),
        'kdr':  make_mechanism_recorders(dend_seg, 'kdr', 'gkdrbar', 'ek', {'n': 1}),
        'kca':  make_mechanism_recorders(dend_seg, 'kca', 'gbar', 'ek', {'m': 3}),
        'kad':  make_mechanism_recorders(dend_seg, 'kad', 'gkabar', 'ek', {'n': 1, 'l': 1}),
        'cagk': make_mechanism_recorders(dend_seg, 'cagk', 'gbar', 'ek', {'o': 1}),
        'hd':   make_mechanism_recorders(dend_seg, 'hd', 'ghdbar', -30, {'l': 1}),
        'cat':  make_mechanism_recorders(dend_seg, 'cat', 'gcatbar', 'eca', {'m': 2, 'h': 1},
                                         calcium_current=True, with_h2=False),
        'cal':  make_mechanism_recorders(dend_seg, 'cal', 'gcalbar', 'eca', {'m': 2},
                                         calcium_current=True, with_h2=True),
        'can':  make_mechanism_recorders(dend_seg, 'can', 'gcanbar', 'eca', {'m': 2, 'h': 1},
                                         calcium_current=True, with_h2=True),
    }
}
try:
    mech_rec['dend']['nap'] = make_mechanism_recorders(dend_seg, 'nap', 'gnabar', 'ena', {'n': 3})
except (AttributeError, NameError):
    # Only tolerate the error NEURON raises when 'nap' range variables do not exist at
    # dend_seg (NameError or AttributeError depending on the case - TODO confirm);
    # anything else (including KeyboardInterrupt) now propagates instead of being swallowed.
    print('The cell does not have a persistent sodium current in its apical dendrite.')

#### Compute the presynaptic spike times

In [None]:
t0 = 500.     # [ms] time of the first trial
dt = 2000.    # [ms] interval between consecutive trials

single_trial = False                     # keep only the last trial (all spines active)
with_somatic_current_injection = False   # inject a somatic current step during the last trial

# Trial k (one every `dt` ms) activates spines 0..k, so each trial recruits one more spine.
# presyn_spike_times[i] holds the spike times of the synapse in spine i.
if with_somatic_current_injection:
    presyn_spike_times = [np.sort(t0 + n_spines * dt - np.arange(i) * dt) for i in range(n_spines+1, 0, -1)]
    presyn_spike_times = presyn_spike_times[:-1]
else:
    presyn_spike_times = [np.sort(t0 + (n_spines - 1) * dt - np.arange(i) * dt) for i in range(n_spines, 0, -1)]

# Set to True to jitter the spike times with Poisson-distributed intervals at
# config['poisson_frequency'] instead of the fixed `spike_dt` offsets (disabled by default).
use_poisson_jitter = False
if use_poisson_jitter and 'poisson_frequency' in config:
    F = config['poisson_frequency']
    for i in range(n_spines):
        ISI = - np.log(np.random.uniform(size=n_spines*2)) / F
        spks = np.cumsum(ISI)
        for j in range(n_spines - i):
            presyn_spike_times[j][i] += spks[j]
else:
    # [ms] delay between the activation of consecutive spines within a trial
    spike_dt = config.get('spike_dt', 0.3)
    for i in range(n_spines):
        presyn_spike_times[i] += i * spike_dt

if single_trial:
    presyn_spike_times = [np.array([spks[-1] - presyn_spike_times[0][-1] + t0]) for spks in presyn_spike_times]

for syn, spks in zip(synapses, presyn_spike_times):
    syn.set_presynaptic_spike_times(spks)

if with_somatic_current_injection:
    # `stim` was previously used without ever being created (NameError): build the clamp here.
    stim = {'soma': h.IClamp(cell.morpho.soma[0](0.5))}
    stim['soma'].amp = 0.06    # [nA]
    stim['soma'].dur = 100     # [ms]
    stim['soma'].delay = presyn_spike_times[0][-1]

In [None]:
# Presynaptic spike times [ms]: one array per synapse (i.e. per spine).
presyn_spike_times

#### Run the simulation

In [None]:
h.cvode_active(1)   # variable time-step integration
# Stop `dt` ms after the last spike of the first synapse, plus one extra `dt`
# when the somatic current injection is enabled.
extra_time = dt if with_somatic_current_injection else 0
h.tstop = presyn_spike_times[0][-1] + dt + extra_time
h.run();

#### Get the data from the recorders

In [None]:
t = np.array(rec['t'])
Vspine = np.array([np.array(rec[f'Vspine-{k}']) for k in range(n_spines)])
Vdend = np.array([np.array(rec[f'Vdend-{k}']) for k in range(n_spines)])
Vsoma = np.array(rec['Vsoma'])

def _to_arrays(groups):
    """Convert a list of lists of h.Vector into a list of lists of numpy arrays."""
    return [[np.array(vec) for vec in group] for group in groups]

prop_V, prop_INa, prop_IK, prop_ICa, prop_Cai = map(
    _to_arrays, (prop_rec, prop_rec_ina, prop_rec_ik, prop_rec_ica, prop_rec_cai))

### spike times
spks = np.array(rec['spike_times'])
# somatic spikes occurring in the 200 ms after each presynaptic spike of the first synapse
spike_times = []
for trial_start in presyn_spike_times[0]:
    in_window = (spks > trial_start) & (spks < trial_start + 200)
    spike_times.append(spks[in_window])
N_spikes = np.array([len(trial_spikes) for trial_spikes in spike_times])
print(N_spikes)

1. The somatic voltage is stored in `prop_V[-1][0]`.
1. The membrane voltage of the dendritic segment connected to the first spine is stored in `prop_V[0][-1]`.

In [None]:
# Sanity check: the last segment recorded on the path to the soma is the dendritic segment
# under the first spine, so the two traces must coincide. Fail loudly instead of only displaying.
dend_traces_match = np.array_equal(prop_V[0][-1], Vdend[0])
assert dend_traces_match, 'prop_V[0][-1] should be the voltage of the dendrite under the first spine'
dend_traces_match

In [None]:
# Sanity check: the last group of propagation recordings is the center of the soma.
soma_traces_match = np.array_equal(prop_V[-1][0], Vsoma)
assert soma_traces_match, 'prop_V[-1][0] should be the somatic voltage'
soma_traces_match

In [None]:
# currents[location][mechanism] -> reconstructed current trace (see compute_current)
currents = {
    location: {mech_name: compute_current(recorders) for mech_name, recorders in mechs.items()}
    for location, mechs in mech_rec.items()
}

In [None]:
# Overview of each reconstructed current (soma and, when recorded, dendrite) in a short window.
current_names = 'nax', 'nap', 'kdr', 'kap', 'kmb', 'kca', 'cagk', 'cat', 'cal', 'can', 'hd'
n_currents = len(current_names)
fig, ax = plt.subplots(1, n_currents, figsize=(1.2 * n_currents, 3))
xlim = [13000, 13050]   # [ms] displayed time window
idx = (t > xlim[0]) & (t < xlim[1])
cmap = plt.get_cmap('Paired')

def _ion_group(current_name):
    """Color group: 0 if the name contains 'na', else 1 for 'k', 2 for 'ca', 3 otherwise."""
    for group, tag in enumerate(('na', 'k', 'ca')):
        if tag in current_name:
            return group
    return 3

for a, name in zip(ax, current_names):
    jdx = _ion_group(name)
    # soma and dendrite use two consecutive entries (2*jdx and 2*jdx+1) of the colormap
    a.plot(t[idx], currents['soma'][name][idx], color=cmap(jdx * 2), lw=1)
    if name in currents['dend']:
        a.plot(t[idx], currents['dend'][name][idx], color=cmap(jdx * 2 + 1), lw=1)
    a.set_title(name)
    a.set_xticks([])
    a.set_yticks([])
    a.axis('off')

fig.tight_layout()

#### Save the traces containing spikes to an Excel spreadsheet

#### Plot the results

In [None]:
# Voltage traces around the first trial that elicited at least one somatic spike.
trials_with_spikes, = np.where(N_spikes > 0)
if trials_with_spikes.size == 0:
    raise RuntimeError('No trial elicited a somatic spike: nothing to plot.')
first = trials_with_spikes[0]
window = [5,100]   # [ms] before and after the first presynaptic spike of the trial
t0 = presyn_spike_times[0][first]
idx, = np.where((t > t0 - window[0]) & (t < t0 + window[1]))
fig,ax = plt.subplots(1, 1, figsize=(8,3))
ax.plot(t[idx] - t0, Vspine[0][idx], color=[1,.5,1], linewidth=2, label='Spine')
ax.plot(t[idx] - t0, Vdend[0][idx], color=[.2,1,.2], linewidth=2, label='Dendrite')
ax.plot(t[idx] - t0, Vsoma[idx], 'k', linewidth=1, label='Soma')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Vm (mV)')
ax.legend(loc='upper right')
for side in ('right','top'):
    ax.spines[side].set_visible(False)
fig.tight_layout();

In [None]:
# Voltage traces around the last trial that elicited at least one somatic spike.
# NOTE: `t0` and `idx` set here are reused by the propagation plot in the next cell.
trials_with_spikes, = np.where(N_spikes > 0)
if trials_with_spikes.size == 0:
    raise RuntimeError('No trial elicited a somatic spike: nothing to plot.')
last = trials_with_spikes[-1]
window = [5,200]   # [ms] before and after the first presynaptic spike of the trial
t0 = presyn_spike_times[0][last]
idx, = np.where((t > t0 - window[0]) & (t < t0 + window[1]))
fig,ax = plt.subplots(1, 1, figsize=(8,3))
ax.plot(t[idx] - t0, Vspine[0][idx], color=[1,.5,1], linewidth=2, label='Spine')
ax.plot(t[idx] - t0, Vdend[0][idx], color=[.2,1,.2], linewidth=2, label='Dendrite')
ax.plot(t[idx] - t0, Vsoma[idx], 'k', linewidth=1, label='Soma')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Vm (mV)')
ax.legend(loc='upper right')
for side in ('right','top'):
    ax.spines[side].set_visible(False)
fig.tight_layout();

### Plot the results
Here we see how membrane voltage, internal calcium concentration and ionic currents evolve as we move from the first spine towards the soma. The three bottom rows of panels display current densities on the left and absolute currents on the right.

In [None]:
# Voltage, [Ca]_i and ionic currents along the path from the first spine to the soma,
# in the time window (`idx`, `t0`) selected by the previous cell.
colors = [[0,0,0], [1,0,0], [0,1,0], [0,0,1], [0,1,1], [1,0,1], [1,1,0], [1,.5,0]]
fig = plt.figure(constrained_layout=True, figsize=(9,12))
gs = fig.add_gridspec(5, 2)
ax_V = fig.add_subplot(gs[0,:])
ax_Ca = fig.add_subplot(gs[1,:])
ax = [[fig.add_subplot(gs[i,j]) for j in (0,1)] for i in (2,3,4)]
n_groups = len(prop_V)
for i in range(n_groups):
    # one base color per section, cycled if the path crosses more sections than colors
    # (indexing without the modulo raised IndexError beyond 8 sections)
    base_col = colors[i % len(colors)]
    lw = 2 if i == n_groups - 1 else 1   # thicker line for the soma
    for j in range(len(prop_INa[i])):
        # successive segments of the same section get progressively lighter shades
        col = np.min([[1,1,1], base_col + j * 0.2 * np.ones(3)], axis=0)
        ax_V.plot(t[idx] - t0, prop_V[i][j][idx], color=col, linewidth=lw)
        ax_Ca.plot(t[idx] - t0, prop_Cai[i][j][idx]*1e6, color=col, linewidth=lw)   # mM -> nM
        ax[0][0].plot(t[idx] - t0, prop_INa[i][j][idx], color=col, linewidth=lw)
        ax[1][0].plot(t[idx] - t0, prop_IK[i][j][idx], color=col, linewidth=lw)
        ax[2][0].plot(t[idx] - t0, prop_ICa[i][j][idx], color=col, linewidth=lw)
        # mA/cm2 * um2 * 1e-2 = nA
        ax[0][1].plot(t[idx] - t0, prop_INa[i][j][idx] * seg_area[i][j] * 1e-2, color=col, linewidth=lw)
        ax[1][1].plot(t[idx] - t0, prop_IK[i][j][idx] * seg_area[i][j] * 1e-2, color=col, linewidth=lw)
        ax[2][1].plot(t[idx] - t0, prop_ICa[i][j][idx] * seg_area[i][j] * 1e-2, color=col, linewidth=lw)
all_axes = [ax_V, ax_Ca] + [a for row in ax for a in row]
for a in all_axes:
    for side in ('right','top'):
        a.spines[side].set_visible(False)
    # plt.tick_params would only affect the last-created axes
    a.tick_params(labelsize=8)
ax_Ca.set_xlabel('Time (ms)')
ax_V.set_ylabel('Vm (mV)')
ax_Ca.set_ylabel('Ca_i (nM)')
ax[0][0].set_ylabel('I_Na (mA/cm2)')
ax[1][0].set_ylabel('I_K (mA/cm2)')
ax[2][0].set_ylabel('I_Ca (mA/cm2)')
ax[0][1].set_ylabel('I_Na (nA)')
ax[1][1].set_ylabel('I_K (nA)')
ax[2][1].set_ylabel('I_Ca (nA)')
ax[2][0].set_xlabel('Time (ms)')
ax[2][1].set_xlabel('Time (ms)')

### Plot a currentscape
Algorithm taken from this paper:<br/>
Alonso, L.M. and Marder, E., 2019. Visualization of currents in neural models with similar behavior and different conductance densities. Elife, 8, p.e42722.

In [None]:
window = [15, 50]   # [ms] margin before the first and after the last spike of the trial
first_spiking_trial = np.where(N_spikes > 0)[0][0]
trial_spikes = spike_times[first_spiking_trial]
t0 = trial_spikes[0]
idx, = np.where((t > t0 - window[0]) & (t < trial_spikes[-1] + window[1]))

In [None]:
current_density = True   # True: current densities [pA/um2]; False: currents [pA]
loc = 'soma'             # 'soma' or 'dend'
N_samples = idx.size
# 'nap' is listed for both locations: it only contributes when it was recorded at `loc`
# (it used to be dropped for 'dend' even when the dendrite had a persistent Na current).
inward = 'nax', 'nap', 'cal', 'cat', 'can'
outward = 'kdr', 'kap', 'kmb', 'kca', 'cagk', 'hd', 'kad'
area = {'soma': soma_seg.area(), 'dend': dend_seg.area()}   # [um2]

def _scaled_current(I):
    """Current `I` [mA/cm2] restricted to the window `idx`, in pA/um2 or pA."""
    if current_density:
        return I[idx] * 1e1                  # mA/cm2 -> pA/um2
    return I[idx] * area[loc] * 1e+1          # mA/cm2 * um2 -> pA

# Inward currents keep only their negative part, outward currents only their positive part.
inward_parts, outward_parts = {}, {}
for name, I in currents[loc].items():
    x = _scaled_current(I)
    if name in inward:
        x[x > 0] = 0
        inward_parts[name] = x
    elif name in outward:
        x[x < 0] = 0
        outward_parts[name] = x

inward_current = np.zeros(N_samples)
for x in inward_parts.values():
    inward_current += x
outward_current = np.zeros(N_samples)
for x in outward_parts.values():
    outward_current += x

# share [%] of each current in the total inward/outward current, at each time sample
normalized_inward = {name: x / inward_current * 100 for name, x in inward_parts.items()}
normalized_outward = {name: x / outward_current * 100 for name, x in outward_parts.items()}

# currents sorted by increasing mean share
inward_percent = sorted([[k,v.mean()] for k,v in normalized_inward.items()], key=lambda x: x[1])
inward_order = [elem[0] for elem in inward_percent]
print(inward_order)
outward_percent = sorted([[k,v.mean()] for k,v in normalized_outward.items()], key=lambda x: x[1])
outward_order = [elem[0] for elem in outward_percent]
print(outward_order)

# fixed color index of each current, so the same current has the same color in every figure
current_codes = {
    'hd': 0, 'cagk': 1, 'kdr': 2, 'kca': 3, 'kmb': 4, 'kap': 5, 'kad': 6,
    'cal': 7, 'can': 8, 'cat': 9, 'nap': 10, 'nax': 11
}
n_currents = len(current_codes)

In [None]:
def hex2rgb_cmap(i, colors):
    """Return the i-th color of `colors` ('#RRGGBB' strings, cycled) as an RGB tuple in [0, 1]."""
    hex_code = colors[i % len(colors)]
    return tuple(int(hex_code[pos:pos + 2], 16) / 255 for pos in (1, 3, 5))

# 12 colors, one per entry of current_codes
my_cmap = ['#C30000', '#00CE00', '#2655C1',
           '#CA00FF', '#FFF600', '#26DFC1',
           '#FF7F00', '#A0A0A0', '#CE9944',
           '#417A44', '#76232F', '#F0C5F0']

In [None]:
def _plot_current_shares(ax_main, ax_inset, order, normalized, mean_percent,
                         times, label_x, yscale, colors_by_name):
    """Draw one currentscape panel: the stacked percentage share of each current over time.

    Parameters
    ----------
    ax_main : axes receiving one filled band per current, stacked in `order`, and its label.
    ax_inset : narrow axes receiving a single stacked bar of the time-averaged shares.
    order : current names, in stacking order.
    normalized : dict name -> array with the percentage share at each sample of `times`.
    mean_percent : list of [name, mean share] pairs, in the same order as `order`.
    times : time samples [ms].
    label_x : x coordinate (data units) of the right-aligned current labels.
    yscale : 'log' or 'linear'; on a log axis shares below 0.1 % are drawn at 0.1 %.
    colors_by_name : dict name -> RGB color.
    """
    cumulative = np.zeros(len(times))
    cumulative_percent = 0
    label_y = np.logspace(np.log10(0.2), np.log10(90), len(order))
    for k, name in enumerate(order):
        prev = cumulative.copy()
        cumulative += normalized[name]
        # `curr` must exist for every y scale (it used to be defined only for 'log')
        curr = cumulative.copy()
        if yscale == 'log':
            prev[prev < 0.1] = 0.1
            curr[curr < 0.1] = 0.1
        col = colors_by_name[name]
        ax_main.fill_between(times, curr, prev, color=col, label=name)
        ax_main.text(label_x, label_y[k], name, color=col,
                     horizontalalignment='right', verticalalignment='center')
        prev_percent = cumulative_percent + np.zeros(2)
        cumulative_percent += mean_percent[k][1]
        ax_inset.fill_between([0, 1], cumulative_percent + np.zeros(2), prev_percent, color=col)


yscale = 'log'   # y scale of the two percentage panels: 'log' or 'linear'

# Figure-wide style (these calls change matplotlib's global rcParams).
matplotlib.rc('font', family='arial', size=6)
matplotlib.rc('axes', linewidth=0.7)
matplotlib.rc('ytick.major', size=2, width=0.7)
fig = plt.figure(figsize=(8 / 2.54, 9 / 2.54))   # 8 cm x 9 cm

# Layout, in figure-fraction units.
top_h = 0.3      # height of the top axes
inset_w = 0.05   # width of the two insets
rows, cols = 4, 1
xspace, yspace = 0.02, 0.02
xoffset, yoffset = [0.25, inset_w * 3], [0.02, 0.02 + top_h + yspace]
w = (1 - np.sum(xoffset) - (cols - 1) * xspace) / cols
# Height of each lower row. Deliberately NOT named `h`: that is NEURON's hoc object,
# used by earlier cells (and by compute_current), and overwriting it breaks re-running them.
row_h = (1 - np.sum(yoffset) - (rows - 1) * yspace) / rows
row_bottoms = {i: 1 - yoffset[1] - row_h * i - yspace * (i - 1) for i in range(1, rows + 1)}
ax = [fig.add_axes([xoffset[0], 1 - 0.02 - top_h, w, top_h])]
for i in range(1, rows + 1):
    ax.append(fig.add_axes([xoffset[0], row_bottoms[i], w, row_h]))
# insets aligned with the two percentage panels, ax[2] and ax[3]
insets = [fig.add_axes([1 - inset_w * 1.2, row_bottoms[i], inset_w, row_h]) for i in range(2, 4)]

t_win = t[idx]

# Panel 0: membrane voltage.
ax[0].plot(t_win, Vspine[0][idx], color=[1,.5,1], linewidth=1, label='Spine')
ax[0].plot(t_win, Vdend[0][idx], color=[.2,1,.2], linewidth=1, label='Dendrite')
ax[0].plot(t_win, Vsoma[idx], 'k', linewidth=1, label='Soma')
ax[0].legend(loc='upper right', frameon=False)
ax[0].set_yticks(np.r_[-80 : 60 : 20])
ax[0].set_ylabel('[mV]')

# Panel 1: total outward current, with a time scale bar.
ax[1].fill_between(t_win, outward_current, 1e-2 + np.zeros(outward_current.shape), color='k')
ax[1].set_yscale('log')
if current_density:
    ax[1].set_ylabel(r'+[pA/$\mu$m$^2$]')
    y_coord = 1
else:
    ax[1].set_ylabel('+[pA]')
    y_coord = 100
# length of the scale bar; not named `dt`, which holds the inter-trial interval used by the run cell
scalebar_dt = 10   # [ms]
t_scale = t[idx[-1]] - scalebar_dt * np.array([1.5, 0.5])
ax[1].plot(t_scale, y_coord + np.zeros(2), 'k', lw=2)
ax[1].text(np.sum(t_scale) / 2, y_coord * 0.7, f'{scalebar_dt} ms', horizontalalignment='center',
           verticalalignment='top')
# same limits for thorny and a-thorny cells; set them per `cell_type` here if they need to differ
ax[1].set_ylim([1e-2, 2e1])
yticks = np.logspace(-2, 1, 4)
ax[1].set_yticks(yticks)
ax[1].set_yticklabels(['{:g}'.format(yt) for yt in yticks])

# Panels 2 and 3: percentage shares of the outward and inward currents.
cmap = lambda i: hex2rgb_cmap(i, my_cmap)
colors_by_name = {name: cmap(code) for name, code in current_codes.items()}
window_dur = np.diff(t[idx[[0,-1]]])[0]   # [ms] duration of the plotted window
label_x = t[idx[0]] - window_dur/4
_plot_current_shares(ax[2], insets[0], outward_order, normalized_outward, outward_percent,
                     t_win, label_x, yscale, colors_by_name)
ax[2].set_yscale(yscale)
ax[2].set_ylabel('Outward %')
_plot_current_shares(ax[3], insets[1], inward_order, normalized_inward, inward_percent,
                     t_win, label_x, yscale, colors_by_name)
ax[3].set_yscale(yscale)
ax[3].set_ylabel('Inward %')

# Panel 4: total inward current (plotted with flipped sign, on an inverted axis).
ax[4].fill_between(t_win, -inward_current, 1e-4 + np.zeros(inward_current.shape), color='k')
ax[4].set_yscale('log')
if current_density:
    ax[4].set_ylabel(r'-[pA/$\mu$m$^2$]')
else:
    ax[4].set_ylabel('-[pA]')
# same limits for thorny and a-thorny cells; set them per `cell_type` here if they need to differ
ax[4].set_ylim([1e-4, 1e1])
yticks = np.logspace(-4, 1, 6)
ax[4].set_yticks(yticks)
ax[4].set_yticklabels(['{:g}'.format(yt) for yt in yticks])
ax[4].invert_yaxis()

if yscale == 'log':
    ylim = [1e-1, 100]
    yticks = np.logspace(-1, 2, 4)
else:
    ylim = [0, 100]
    yticks = np.linspace(0, 100, 11)

for a in ax[2:4]:
    a.set_ylim(ylim)
    a.set_yticks(yticks)
    a.set_yticklabels(['{:g}'.format(yt) for yt in yticks])

for a in insets:
    a.set_xlim([0, 1])
    a.set_xticks([])
    a.set_ylim(ylim)
    a.set_yscale('linear')
    a.set_yticks(np.linspace(0, 100, 5))

ax[3].invert_yaxis()
insets[1].invert_yaxis()

for a in ax:
    a.set_xticks([])
    a.set_xlim(t[idx[[0,-1]]])
    for side in 'right','top','bottom':
        a.spines[side].set_visible(False)
for a in ax[0], ax[1], ax[-1]:
    a.grid(which='major', axis='y', lw=0.5, ls=':', color=[.6,.6,.6])

for a in ax:
    a.minorticks_off()

fig.savefig(f'synaptic_inputs_active_currentscape_{cell_type}_{individual}_{loc}_{yscale}.pdf')

#### Save the data