**Note.** *The following notebook contains code in addition to text and figures. By default, the code has been hidden. You can click the icon that looks like an eye in the toolbar above to show the code. Also note that the code runs when you first open the notebook, so it may take a few seconds to a minute before the figures appear.*

# Level dependence

In this notebook, we examine the level dependence of the CV and firing rate of chopper cells. We compare experimental data with the results that can be achieved with the model introduced in the notebooks on the [Basic Model](basic_model.ipynb) and [Behaviour Maps](maps.ipynb).

In [None]:
%%html
<!-- hack to improve styling of ipywidgets sliders -->
<style type="text/css">
.widget-hbox .widget-label {
    min-width: 35ex;
    max-width: 35ex;
}
.widget-hslider {
    width: 100%;
}
.widget-hprogress {
    width: 100%;
}

</style>

In [None]:
# Imports etc.
%matplotlib inline
from brian2 import *
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy.random as np_rand
from functools import partial
from collections import OrderedDict
import ipywidgets as ipw

import warnings
warnings.filterwarnings("ignore")

# Default simulation time step for Brian2 objects that do not set their own dt
defaultclock.dt = 0.05*ms

# Load experimental data summary stats: CV and firing rate at 20 and 50 dB.
# The four arrays are combined element-wise below, so entry i of each file
# must describe the same recorded cell.
exp_cv20 = loadtxt('expdata_summary/allcv20.txt')
exp_cv50 = loadtxt('expdata_summary/allcv50.txt')
exp_fr20 = loadtxt('expdata_summary/allfr20.txt')
exp_fr50 = loadtxt('expdata_summary/allfr50.txt')
# Boolean mask of cells whose CV lies strictly on opposite sides of 0.35 at the
# two levels (the "Mixed" group of the plots)
crossing = ((exp_cv20>0.35)&(exp_cv50<0.35))|((exp_cv20<0.35)&(exp_cv50>0.35))
# Per-cell change from 20 dB to 50 dB
exp_frdiff = exp_fr50-exp_fr20
exp_cvdiff = exp_cv50-exp_cv20

# Plotting functions

# Utility function to create axes on the top and right
# We use this below to create histograms on the sides of a plot
def get_sidehist_axes(ax=None):
    """Attach two bare strips to ``ax`` for marginal histograms.

    A top strip (sharing x with ``ax``) and a right strip (sharing y with
    ``ax``) are appended flush against ``ax`` (size 0.5, no padding). Both have
    their tick labels, tick lines and frame hidden. ``ax`` defaults to the
    current axes and is made current again before returning.

    Returns
    -------
    (ax_top, ax_right)
    """
    main_ax = gca() if ax is None else ax
    divider = make_axes_locatable(main_ax)

    def strip(side_ax):
        # Hide every decoration so only the histogram bars remain visible
        for getter in (side_ax.get_xticklabels, side_ax.get_xticklines,
                       side_ax.get_yticklabels, side_ax.get_yticklines):
            setp(getter(), visible=False)
        side_ax.set_frame_on(False)
        return side_ax

    ax_top = strip(divider.append_axes("top", size=0.5, pad=0.0, sharex=main_ax))
    ax_right = strip(divider.append_axes("right", size=0.5, pad=0.0, sharey=main_ax))
    # append_axes may change the current axes; restore the main one
    sca(main_ax)
    return ax_top, ax_right

# Plot CV and firing rate, pass varx=fr20, vary=cv20 or fr50, cv50 to
# plot panels A and B
def pointplot(fr20, fr50, cv20, cv50, varx, vary, dolabel=False, col=None, ls=None,
              sidehists=True, muted=False):
    """Scatter ``vary`` against ``varx`` with cells grouped by CV category.

    Cells are grouped using the CV boundary 0.35 at the two levels:
    'Sustained' (below at both; '^', blue), 'Transient' (above at both;
    's', green) and 'Mixed' (below at one level and above at the other;
    'o', red). A cell with a CV exactly equal to 0.35 falls in no group.

    ``ls`` replaces the format string and ``col`` the colour for every group.
    ``muted=True`` forces pale versions of the default colours. Legend labels
    are attached only when ``dolabel`` is true. When ``sidehists`` is true,
    20-bin histograms of all ``varx`` and ``vary`` values are drawn above and to
    the right of the current axes. ``fr20`` and ``fr50`` are not used; they only
    keep the call signature uniform with the other plotting helpers.
    """
    below20, above20 = cv20 < 0.35, cv20 > 0.35
    below50, above50 = cv50 < 0.35, cv50 > 0.35
    groups = (
        # (mask, default format, default colour, muted colour, legend label)
        (below20*below50, '^', 'b', (0.5, 0.5, 1), 'Sustained'),
        (above20*above50, 's', 'g', (0.5, 1, 0.5), 'Transient'),
        ((below20*above50)+(above20*below50), 'o', 'r', (1, 0.5, 0.5), 'Mixed'),
    )
    for mask, fmt, default_col, muted_col, group_label in groups:
        if muted:
            col = muted_col
        plot(varx[mask], vary[mask], ls if ls else fmt,
             ms=6, mec='none', c=col if col else default_col,
             label=group_label if dolabel else None)
    if sidehists:
        ax_top, ax_right = get_sidehist_axes()
        black = (0.0, 0.0, 0.0)
        ax_top.hist(varx, 20, facecolor=black, ec='none')
        ax_right.hist(vary, 20, facecolor=black, ec='none', orientation='horizontal')

# Plot panel C
def diffplot(fr20, fr50, cv20, cv50, dolabel=True, muted=False):
    """Panel C: per-cell change from 20 to 50 dB (rate change vs CV change).

    Points are coloured by CV category via ``pointplot``. Zero lines are drawn
    on both axes and the view is fixed to +/-150 sp/s by +/-0.25 CV.
    """
    pointplot(fr20, fr50, cv20, cv50, fr50 - fr20, cv50 - cv20,
              dolabel=dolabel, muted=muted)
    # pointplot leaves the main scatter axes current
    ax = gca()
    ax.set_xlabel('Firing rate difference (sp/s)')
    ax.set_ylabel('CV difference')
    ax.axhline(0, ls='-', c='k')
    ax.axvline(0, ls='-', c='k')
    ax.set_xlim(-150, 150)
    ax.set_ylim(-0.25, 0.25)

# This and the next function plot panel D
def arrowplot(allfr20, allfr50, allcv20, allcv50, arrowlength=1.0, muted=False):
    """Panel D: one arrow per cell from its 20 dB to its 50 dB (rate, CV) point.

    Arrow colours follow the same scheme as ``pointplot``:
    - blue when CV < 0.35 at both levels;
    - green when CV > 0.35 at both levels;
    - red otherwise.
    Pale shades are used when ``muted`` is true.

    ``arrowlength`` scales each arrow relative to the full 20 -> 50 dB
    displacement. The view is fixed to 50-500 sp/s by 0.1-0.7 CV, with a
    dashed line at the 0.35 category boundary.
    """
    for fr20, fr50, cv20, cv50 in zip(allfr20, allfr50, allcv20, allcv50):
        if cv20 < 0.35 and cv50 < 0.35:
            c = (0.5, 0.5, 1) if muted else 'b'
        elif cv20 > 0.35 and cv50 > 0.35:
            c = (0.5, 1, 0.5) if muted else 'g'
        else:
            c = (1, 0.5, 0.5) if muted else 'r'
        annotate('', xytext=(fr20, cv20),
                 xy=(fr20+arrowlength*(fr50-fr20), cv20+arrowlength*(cv50-cv20)),
                 arrowprops=dict(arrowstyle='->', ec=c))
    axhline(0.35, ls='--', c='k')
    xlabel('Firing rate (sp/s)')
    ylabel('CV')
    # Fixed limits only: limits derived from amin/amax of the data were always
    # overridden here, and they raised on empty input or NaN CVs.
    xlim(50, 500)
    ylim(0.1, 0.7)

def diff_arrow_plot(allfr20, allfr50, allcv20, allcv50, dolabel=True, arrowlength=1.0, muted=False):
    """Two-panel figure of the 20 -> 50 dB change.

    Left panel: the difference scatter (``diffplot``). Right panel: the arrow
    plot (``arrowplot``). When ``dolabel`` is true, legend labels are shown and
    the panels are lettered A and B.
    """
    subplot(121)
    diffplot(allfr20, allfr50, allcv20, allcv50, dolabel=dolabel, muted=muted)
    subplot(122)
    arrowplot(allfr20, allfr50, allcv20, allcv50, arrowlength=arrowlength, muted=muted)
    # Empty marginal strips so the right panel's layout matches the left one
    get_sidehist_axes()
    if dolabel:
        # range (not the Python-2-only xrange) so this also runs on Python 3
        for i in range(2):
            figtext(0.01+i/2.0, 0.95, chr(ord('A')+i), size='x-large')

# Plot all panels
def cvfr_level_dependence_plot(allfr20, allfr50, allcv20, allcv50,
                               muted=False, axes=None,
                               highlighted=None):
    """Draw the four-panel CV / firing-rate level-dependence figure.

    Panels:
    - A: CV vs rate at 20 dB.
    - B: CV vs rate at 50 dB.
    - C: change from 20 to 50 dB as points (``diffplot``).
    - D: the same change as arrows (``arrowplot``).

    Parameters
    ----------
    allfr20, allfr50, allcv20, allcv50 : arrays
        Firing rates (sp/s) and CVs of each cell at 20 and 50 dB re threshold,
        aligned element-wise.
    muted : bool
        Draw the population in pale colours (e.g. when overlaying a model).
    axes : sequence of 4 axes, or None
        Axes for panels A-D. If None, a new 7x6 figure is created, panel
        letters are added and ``tight_layout`` is applied.
    highlighted : (fr20, fr50, cv20, cv50), or None
        A single cell marked with a black star in A-C and a thick black arrow
        in D; it is labelled "Model" in the legend of C.
    """
    if axes is None:
        figure(figsize=(7, 6))
        ax20 = subplot(221)
        ax50 = subplot(222)
        axdiff = subplot(223)
        axarrow = subplot(224)
    else:
        ax20, ax50, axdiff, axarrow = axes
    if highlighted is not None:
        hfr20, hfr50, hcv20, hcv50 = highlighted
    else:
        hfr20 = hfr50 = hcv20 = hcv50 = None

    def level_panel(ax, fr, cv, hfr, hcv, level_title):
        # Panels A/B: CV vs rate at one level; colours use both levels' CVs
        sca(ax)
        pointplot(allfr20, allfr50, allcv20, allcv50, fr, cv, muted=muted)
        if highlighted is not None:
            plot(hfr, hcv, '*k', ms=12)
        xlabel('Firing rate (sp/s)')
        ylabel('CV')
        title(level_title, y=0.87)
        ylim(0.1, 0.7)
        xlim(0, 500)
        axhline(0.35, ls='--', c='k')

    level_panel(ax20, allfr20, allcv20, hfr20, hcv20, '20 dB re threshold')
    level_panel(ax50, allfr50, allcv50, hfr50, hcv50, '50 dB re threshold')

    sca(axdiff)
    diffplot(allfr20, allfr50, allcv20, allcv50, muted=muted)
    if highlighted is not None:
        plot(hfr50-hfr20, hcv50-hcv20, '*k', ms=12, label="Model")
    legend(loc='lower left', fontsize='small')

    sca(axarrow)
    arrowplot(allfr20, allfr50, allcv20, allcv50, muted=muted)
    if highlighted is not None:
        arrowlength = 1.0
        annotate('', xytext=(hfr20, hcv20),
                 xy=(hfr20+arrowlength*(hfr50-hfr20), hcv20+arrowlength*(hcv50-hcv20)),
                 arrowprops=dict(arrowstyle='->', ec='k', lw=2))
    # Empty marginal strips so panel D lines up with the scatter panels
    get_sidehist_axes()
    xlabel('Firing rate (sp/s)')
    ylabel('CV')

    if axes is None:
        # Panel letters A-D; range (not the Python-2-only xrange) for Python 3
        for i in range(2):
            for j in range(2):
                figtext(0.01+i/2.0, 0.95-j/2.0, chr(ord('A')+i+2*j), size='x-large')

        tight_layout()

We start with a plot of the experimental data (a new analysis of data recorded in the lab of [Ian Winter](http://www.neuroscience.cam.ac.uk/directory/profile.php?imw1001) over a number of years). It shows that the CV and firing rate of chopper cells change with sound level. The upper two panels show the distribution of these quantities at 20 and 50 dB re threshold. The lower two panels show the change from 20 to 50 dB, as points (C) and as arrows (D). Cells are shown in blue if their CV is below 0.35 at both levels (unambiguous sustained choppers), and in green if it is above 0.35 at both levels (unambiguous transient choppers). They are shown in red if their CV crosses this boundary between the two levels.

In [None]:
# Figure of the experimental data alone (panels A-D)
cvfr_level_dependence_plot(allfr20=exp_fr20, allfr50=exp_fr50,
                           allcv20=exp_cv20, allcv50=exp_cv50)

In [None]:
# Module-level cache for the Brian2 Network built by get_steady_state_data;
# reused across calls and rebuilt only when the number of neurons changes.
thenet = None
def get_steady_state_data(
        repeats=1000,
        duration=250*ms,
        skip=50*ms,
        mu=4.0,
        sigma=0.1,
        tau=5*ms,
        refractory=0*ms,
        ):
    """Simulate ``repeats`` independent noisy neurons and return ``(cv, rate)``.

    Each neuron obeys ``dv/dt = (mu-v)/tau + sigma*tau**-0.5*xi`` (Euler
    integration). It spikes when ``v>1``, is reset to 0, and ``v`` is not
    integrated during the following ``refractory`` period. The first ``skip``
    of ``duration`` is simulated but not recorded.

    Returns
    -------
    cv : float
        std/mean of all interspike intervals (ISIs are taken within each
        neuron's own spike train and then pooled across neurons). ``nan`` if
        fewer than two ISIs were recorded.
    rate : Brian2 quantity (1/second)
        Recorded spikes per neuron per second of recorded time
        (``duration - skip``).
    """
    global thenet
    # Build the network only if none is cached yet or its size differs
    if thenet is None or len(thenet['G'])!=repeats:
        eqs = '''
        dv/dt = (mu-v)/tau+sigma*tau**-0.5*xi : 1 (unless refractory)
        refrac : second
        '''
        # mu, tau and sigma are not defined in eqs: they are supplied at run
        # time through the `namespace` argument of thenet.run below.
        # The refractory period is read per neuron from the `refrac` variable.
        G = NeuronGroup(repeats, eqs, threshold='v>1', reset='v=0', refractory='refrac',
                        name='G', method='euler')
        M = SpikeMonitor(G, name='M')
        thenet = Network(G, M)
        # Snapshot of the freshly built network, rewound to on every call
        thenet.store()
    else:
        # Reuse the cached objects, looked up by the names given at creation
        G = thenet['G']
        M = thenet['M']
    # Rewind to the stored snapshot, then set the per-call state explicitly
    thenet.restore()
    G.refrac = refractory
    G.not_refractory = True
    G.lastspike = -inf*second  # no neuron has spiked yet
    G.v = 0
    ns = {'mu': mu, 'tau': tau, 'sigma': sigma}
    # Run the initial `skip` with the monitor off so the start-up transient is
    # excluded from the statistics
    M.active = False
    thenet.run(skip, namespace=ns)
    M.active = True
    thenet.run(duration-skip, namespace=ns)    
    trains = M.spike_trains()
    # ISIs within each neuron that fired at least twice
    dtrains = [diff(train) for train in trains.values() if len(train)>1]
    if len(dtrains):
        isi = hstack(dtrains)*second
    else:
        isi = array([])
    if len(isi)>1:
        cv = std(isi)/mean(isi)
    else:
        cv = nan
    # Only spikes from the monitored period are in M.t
    rate = len(M.t)/(repeats*(duration-skip))
    return cv, rate

In the following interactive figure, you can explore what happens when the excitatory and inhibitory input firing rates differ between 20 and 50 dB, for different model parameters. In general, higher input firing rates decrease the CV and increase the output firing rate, while stronger inhibition increases the CV and decreases the firing rate. The model result is shown as a black star or arrow, and the experimental results are shown as above but in a lighter shade.

In [None]:
def compare_model_to_data(rho20_Hz=150, rho50_Hz=200, N=40, alpha20=0.0, alpha50=0.5,
                          mu=2.0, tau_ms=6.0, refractory_ms=0.1):
    """Simulate one model chopper cell at 20 and 50 dB and overlay it on the data.

    Parameters
    ----------
    rho20_Hz, rho50_Hz : float
        Auditory-nerve input firing rate (Hz) at 20 and 50 dB.
    N : int
        Number of auditory-nerve inputs.
    alpha20, alpha50 : float
        Inhibitory input rate as a fraction of the excitatory rate at each level.
    mu : float
        Mean drive averaged over the two levels; the synaptic weight is chosen
        to achieve it.
    tau_ms, refractory_ms : float
        Membrane time constant and refractory period (ms) of the simulated cell.

    The experimental population is drawn muted, with the model cell highlighted.
    """
    # Parameters
    tau = tau_ms*ms
    refractory = refractory_ms*ms
    rho20 = rho20_Hz*Hz
    rho50 = rho50_Hz*Hz
    # Synaptic weight such that the net mean drive weight*N*tau*rho*(1-alpha),
    # averaged over the two levels, equals mu
    weight = mu/(N*tau*0.5*(rho20*(1-alpha20)+rho50*(1-alpha50)))

    def f(weight, N, anf_rate_exc, anf_rate_inh):
        # Model chopper cell (CV, rate) for one level
        mu_exc = weight*N*tau*anf_rate_exc
        mu_inh = weight*N*tau*anf_rate_inh
        sigma2_exc = weight*mu_exc
        sigma2_inh = weight*mu_inh
        # tau and refractory must be passed explicitly; otherwise the
        # simulation silently uses get_steady_state_data's own defaults and
        # the corresponding sliders have no (or only partial) effect.
        return get_steady_state_data(mu=mu_exc-mu_inh,
                                     sigma=sqrt(sigma2_exc+sigma2_inh),
                                     tau=tau, refractory=refractory)

    cv20, fr20 = f(weight, N, rho20, rho20*alpha20)
    cv50, fr50 = f(weight, N, rho50, rho50*alpha50)
    cvfr_level_dependence_plot(exp_fr20, exp_fr50, exp_cv20, exp_cv50, muted=True,
                               highlighted=[fr20, fr50, cv20, cv50])
 
# One slider per argument of compare_model_to_data (keys match its parameter
# names). continuous_update=False makes a slider report its value only when it
# is released, so the simulation re-runs once per adjustment rather than while
# dragging.
# NOTE(review): the alpha50 slider starts at 0.4, while compare_model_to_data's
# own default is alpha50=0.5; confirm which value is intended.
widgets = OrderedDict([ # NOTE(review): unpacked as **kwargs below, so this ordering may not decide the on-screen slider order; confirm with the ipywidgets version in use
        ('N', ipw.IntSlider(min=1, max=100, step=1, value=40,
                continuous_update=False,
                description=r"Number of AN fibres $N$")),
        ('mu', ipw.FloatSlider(min=0, max=5, step=0.01, value=2.0,
                continuous_update=False,
                description=r"Mean current $\mu$")),
        ('tau_ms', ipw.FloatSlider(min=0.1, max=15, step=0.1, value=6,
                continuous_update=False,
                description=r"Membrane time constant $\tau$ (ms)")),
        ('refractory_ms', ipw.FloatSlider(min=0, max=5, step=0.1, value=0.1,
                continuous_update=False,
                description=r"Refractory period $t_\mathrm{ref}$ (ms)")),
        ('rho20_Hz', ipw.FloatSlider(min=0, max=500, step=10, value=150,
                continuous_update=False,
                description=r"AN firing rate at 20 dB $\rho_{20}$ (Hz)")),
        ('rho50_Hz', ipw.FloatSlider(min=0, max=500, step=10, value=200,
                continuous_update=False,
                description=r"AN firing rate at 50 dB $\rho_{50}$ (Hz)")),
        ('alpha20', ipw.FloatSlider(min=0, max=1, step=0.01, value=0.0,
                continuous_update=False,
                description=r"Inhibitory fraction at 20 dB $\alpha_{20}$")),
        ('alpha50', ipw.FloatSlider(min=0, max=1, step=0.01, value=0.4,
                continuous_update=False,
                description=r"Inhibitory fraction at 50 dB $\alpha_{50}$")),
    ])

# Trailing semicolon suppresses the notebook's echo of the return value
ipw.interact(compare_model_to_data, **widgets);

Finally, in the next figure (which will appear after a minute or so of computation time), you can see a model of the experimental data in the first figure above. To get this model, we have chosen a particular distribution of model parameters, which are the same at 20 and 50 dB levels except for the excitatory and inhibitory firing rates. The excitatory firing rates always increase, whereas the inhibitory rates can increase or decrease from 20 dB to 50 dB.

This figure is not interactive, because it takes a little while to compute and because tweaking the parameters of this distribution is not very interesting. If you do want to try out other parameter distributions and see what they look like, first show the code (by clicking the eye icon in the toolbar above). Then modify the function ``parameter_distribution`` below, and the ``min_rate`` and ``max_rate`` parameters in the line starting ``params`` near the bottom. Finally, re-run the cell by clicking in it and pressing Ctrl+Enter.

In [None]:
def sigmoid(x, k):
    """Logistic curve centred at x = 0.5 with steepness ``k``.

    Maps 0.5 to 0.5. For positive ``k`` it increases with ``x``, approaching
    1/(1+e^k) at x = 0 and 1/(1+e^-k) at x = 1.
    """
    centred = 2 * x - 1
    return 1.0 / (1.0 + exp(-k * centred))

def parameter_distribution():
    """Draw one random set of model-cell parameters by rejection sampling.

    Draws:
    - mu: normal(2, 0.4).
    - Inhibitory fractions inh_1, inh_2: 0.65*sigmoid(U(0, 1), 6).
    - Input rates anf_rate_1, anf_rate_2: normal(250 Hz, 25 Hz).
    - num_anf: integer uniform in 30..60.
    - tau: log-uniform in 5-15 ms.
    - refractory: log-uniform in 0.1-5 ms.

    The whole draw is repeated until all of these hold:
    - 1 <= mu <= 4;
    - both rates lie in [150, 450] Hz;
    - anf_rate_1 <= anf_rate_2 <= anf_rate_1 + 75 Hz;
    - -0.1 <= inh_2 - inh_1 <= 0.25.

    Returns
    -------
    (mu, num_anf, anf_rate_1, anf_rate_2, inh_1, inh_2,
     dict(tau=..., refractory=...))
    """
    steepness = 6
    inh_max = 0.65
    while True:
        # The order of random draws is kept fixed so seeded runs are reproducible
        mu = randn()*.4+2
        inh_1 = sigmoid(rand(), steepness)*inh_max
        inh_2 = sigmoid(rand(), steepness)*inh_max
        anf_rate_1 = randn()*25*Hz+250*Hz
        anf_rate_2 = randn()*25*Hz+250*Hz
        num_anf = np_rand.randint(30, 60+1)
        tau = exp(np_rand.uniform(log(5), log(15.0)))*ms
        refractory = exp(np_rand.uniform(log(0.1), log(5.0)))*ms

        mu_ok = 1 <= mu <= 4
        rates_ok = (150*Hz <= anf_rate_1 <= 450*Hz) and (150*Hz <= anf_rate_2 <= 450*Hz)
        rate_step_ok = anf_rate_1 <= anf_rate_2 <= anf_rate_1+75*Hz
        inh_step_ok = -0.1 <= inh_2-inh_1 <= 0.25
        if mu_ok and rates_ok and rate_step_ok and inh_step_ok:
            return (mu, num_anf, anf_rate_1, anf_rate_2, inh_1, inh_2,
                    dict(tau=tau, refractory=refractory))

def predict_level_dependence(
          N, dist,
          min_mean_rate=0*Hz, max_mean_rate=inf*Hz,
          min_rate=0*Hz, max_rate=inf*Hz,
          repeats=100, duration=100*ms, skip=15*ms,
          **params):
    """Sample ``N`` model cells from ``dist`` and simulate each at two levels.

    ``dist()`` returns ``(mu, num_anf, anf_rate_1, anf_rate_2, inh_1, inh_2)``,
    optionally followed by a dict of extra parameters (e.g. ``tau``,
    ``refractory``). That dict is merged into ``params`` for the sample.

    For each sample, the synaptic weight is chosen so that the net mean drive,
    averaged over the two levels, equals ``mu``. Each level is then simulated
    with ``get_steady_state_data``.

    A sample is discarded and redrawn if its output rate at either level lies
    outside ``[min_rate, max_rate]``, or its mean rate over the two levels
    lies outside ``[min_mean_rate, max_mean_rate]``. The function returns only
    once ``N`` samples have been accepted.

    Recognised ``params``:
    - ``tau`` (default 6 ms) is used both for the weight and for the simulation.
    - ``inh_skew`` (default 1.0) multiplies the inhibitory weight and divides
      the number of inhibitory inputs.
    - Any other key (e.g. ``refractory``) is forwarded to ``get_steady_state_data``.

    Returns
    -------
    cv_1, cv_2, rate_1, rate_2 : arrays of length ``N``
    """
    def f(weight, num_anf, anf_rate_exc, anf_rate_inh):
        # tau, inh_skew and params are read from the enclosing loop at call
        # time, so each sample's own values are used.
        weight_exc = weight
        weight_inh = inh_skew*weight
        num_anf_exc = num_anf
        num_anf_inh = num_anf/inh_skew
        mu_exc = weight_exc*num_anf_exc*tau*anf_rate_exc
        mu_inh = weight_inh*num_anf_inh*tau*anf_rate_inh
        sigma2_exc = weight_exc*mu_exc
        sigma2_inh = weight_inh*mu_inh
        # Build the simulator arguments per call, so that tau/refractory merged
        # into params by dist() reach the simulation. Binding them once up front
        # (e.g. with functools.partial) would freeze params before any sample
        # is drawn. inh_skew is a model-level parameter the simulator does not
        # accept.
        sim_kwargs = dict(params)
        sim_kwargs.pop('inh_skew', None)
        sim_kwargs.update(tau=tau, mu=mu_exc-mu_inh, sigma=sqrt(sigma2_exc+sigma2_inh))
        return get_steady_state_data(repeats=repeats, duration=duration, skip=skip,
                                     **sim_kwargs)

    all_cv_1 = []
    all_cv_2 = []
    all_cn_rate_1 = []
    all_cn_rate_2 = []
    while len(all_cv_1)<N:
        p = dist()
        if len(p)==7:
            params.update(p[6])
        tau = params.get('tau', 6*ms)
        inh_skew = params.get('inh_skew', 1.0)
        mu_base, num_anf, anf_rate_1, anf_rate_2, inh_1, inh_2 = p[:6]
        weight = mu_base/(num_anf*tau*0.5*(anf_rate_1*(1-inh_1)+anf_rate_2*(1-inh_2)))
        cv_1, cn_rate_1 = f(weight, num_anf, anf_rate_1, anf_rate_1*inh_1)
        cv_2, cn_rate_2 = f(weight, num_anf, anf_rate_2, anf_rate_2*inh_2)
        mean_rate = 0.5*(cn_rate_1+cn_rate_2)
        # Rejection filters on the simulated output rates
        if mean_rate<min_mean_rate or mean_rate>max_mean_rate:
            continue
        if cn_rate_1<min_rate or cn_rate_2<min_rate or cn_rate_1>max_rate or cn_rate_2>max_rate:
            continue
        all_cv_1.append(cv_1)
        all_cv_2.append(cv_2)
        all_cn_rate_1.append(cn_rate_1)
        all_cn_rate_2.append(cn_rate_2)
    return array(all_cv_1), array(all_cv_2), array(all_cn_rate_1), array(all_cn_rate_2)

# One parameter set: 86 model cells drawn from parameter_distribution, keeping
# only cells whose output rate at both levels lies in [100, 450] Hz
params = [dict(N=86, dist=parameter_distribution, min_rate=100*Hz, max_rate=450*Hz)]
# Run every parameter set, then concatenate the per-set results into four arrays
cv1, cv2, rate1, rate2 = map(hstack, zip(*[predict_level_dependence(**p) for p in params]))
# Level 1 is plotted as 20 dB and level 2 as 50 dB
cvfr_level_dependence_plot(rate1, rate2, cv1, cv2)