# Cell types figure (unfinished)

In [None]:
%matplotlib notebook
from brian2 import *
from model_explorer_jupyter import *
import ipywidgets as ipw
from collections import OrderedDict
from scipy.interpolate import interp1d
from matplotlib import cm
from matplotlib.gridspec import GridSpecFromSubplotSpec
import joblib
from scipy.ndimage.interpolation import zoom
from scipy.ndimage.filters import gaussian_filter

# Silence Brian2 log messages registered under the name 'resolution_conflict'.
# NOTE(review): presumably the warning Brian2 emits when an identifier used in
# the model equations resolves to different objects in different namespaces —
# confirm against the Brian2 logging documentation.
BrianLogger.suppress_name('resolution_conflict')

def normed(X, *args):
    """Return X divided by the largest absolute value found in X and all of args.

    The extra arrays only contribute to the scale factor, so several signals
    can be normalised to a common maximum; only the scaled X is returned.
    """
    candidates = (X,) + args
    scale = max([amax(abs(arr)) for arr in candidates])
    return X / scale

# On-disk joblib memoisation cache rooted in the current directory.
# NOTE(review): per the joblib docs, `bytes_limit` is not enforced
# automatically (the cache is only trimmed when `mem.reduce_size()` is
# called), and newer joblib releases deprecate this constructor argument —
# confirm against the installed joblib version.
mem = joblib.Memory(
    location='.',
    bytes_limit=10 * 1024**3,  # nominal 10 GiB cap
    verbose=0,
)

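Raw experimental data the model should reproduce: the modulation frequencies and, for each frequency, the target phase and its standard deviation (converted from degrees to radians)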

In [None]:
# Target data, one entry per modulation frequency (source presumably Dietz et
# al. — confirm). Phases are given in degrees and converted to radians.
dietz_fm = array([4, 8, 16, 32, 64])*Hz               # modulation frequencies
dietz_phase = array([37, 40, 62, 83, 115])*pi/180     # target phase (rad)
dietz_phase_std = array([46, 29, 29, 31, 37])*pi/180  # phase standard deviation (rad)

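Basic model definitions: the simulation (`simple_model`) and the extraction of phase and rate features from its output (`extract_peak_phase`)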

In [None]:
@mem.cache
def simple_model(N, params):
    """Simulate N model units and return the recorded ``out`` variable.

    Parameters
    ----------
    N : int
        Number of units in the NeuronGroup.
    params :
        Passed to ``G.set_states``; presumably a dict mapping the per-unit
        parameters declared in the equations (fc_Hz, fm, tauihc_ms, taue_ms,
        taui_ms, taua_ms, alpha, beta, gamma, level) to scalars or length-N
        arrays — TODO confirm against callers.

    Returns
    -------
    (t, out) :
        ``M.t[:]`` and ``M.out[:]`` from a StateMonitor recording all units.
        The monitor is only added after an initial unrecorded 0.25 s run, so
        ``t`` covers the second 0.25 s of simulated time.

    Results are memoised on disk through ``mem.cache`` (joblib).
    """
    # IHC time constants below this value bypass the low-pass filter (see the
    # int(tauihc<min_tauihc) switches in eqs). Brian resolves this name from
    # the local namespace inside the equations and the condition string below,
    # so the variable must keep this exact name.
    min_tauihc = 0.1*ms
    # Model, per unit:
    #   carrier : half-wave rectified cos(2*pi*fc*t)
    #   A_raw   : (carrier * gain * 0.5*(1 - cos(2*pi*fm*t)))**gamma,
    #             with gain = 10**(level/20)
    #   A       : A_raw if tauihc < min_tauihc, otherwise A_filt, a first-order
    #             low-pass of A_raw with time constant tauihc. When bypassed,
    #             the denominator becomes 1 s + tauihc (avoids division by
    #             zero) and dA_filt/dt is 0 because A == A_raw.
    #   Q       : dQ/dt = -k*Q*A + R*(1-Q), with k = alpha/taua and
    #             R = (1-alpha)/taua (depleted by A, recovers towards 1)
    #   Ae, Ai  : first-order low-pass filters of A*Q with taue and taui
    #   out     : max(Ae - beta*Ai, 0)
    # Time constants are given in ms via the *_ms parameters.
    eqs = '''
    carrier = clip(cos(2*pi*fc*t), 0, Inf) : 1
    A_raw = (carrier*gain*0.5*(1-cos(2*pi*fm*t)))**gamma : 1
    dA_filt/dt = (A_raw-A)/(int(tauihc<min_tauihc)*1*second+tauihc) : 1
    A = A_raw*int(tauihc<min_tauihc)+A_filt*int(tauihc>=min_tauihc) : 1
    dQ/dt = -k*Q*A+R*(1-Q) : 1
    AQ = A*Q : 1
    dAe/dt = (AQ-Ae)/taue : 1
    dAi/dt = (AQ-Ai)/taui : 1
    out = clip(Ae-beta*Ai, 0, Inf) : 1
    gain = 10**(level/20.) : 1
    R = (1-alpha)/taua : Hz
    k = alpha/taua : Hz
    fc = fc_Hz*Hz : Hz
    fc_Hz : 1
    fm : Hz
    tauihc = tauihc_ms*ms : second
    taue = taue_ms*ms : second
    taui = taui_ms*ms : second
    taua = taua_ms*ms : second
    tauihc_ms : 1
    taue_ms : 1
    taui_ms : 1
    taua_ms : 1
    alpha : 1
    beta : 1
    gamma : 1
    level : 1
    '''
    # Forward-Euler integration with a fixed 0.1 ms step.
    G = NeuronGroup(N, eqs, method='euler', dt=0.1*ms)
    G.set_states(params)
    # Clamp sub-threshold IHC time constants to exactly 0 (filter bypassed).
    G.tauihc_ms['tauihc_ms<min_tauihc/ms'] = 0
    # Q starts at 1 for every unit.
    G.Q = 1
    M = StateMonitor(G, 'out', record=True)
    net = Network(G)
    # First 0.25 s runs without the monitor — presumably to let start-up
    # transients settle before recording (TODO confirm intent).
    net.run(.25*second)
    net.add(M)
    # Second 0.25 s is recorded.
    net.run(.25*second)
    return M.t[:], M.out[:]

def extract_peak_phase(N, t, out, error_func, weighted, interpolate_bmf=False):
    """Extract per-frequency phase and rate features from simulated output.

    Parameters
    ----------
    N : int
        Leading dimension: ``out`` is reshaped to
        ``(N, len(dietz_fm), len(t))``, so its rows are assumed to be ordered
        N-major, modulation-frequency-minor (TODO confirm against caller).
    t : array
        Sample times of ``out``.
    out : array
        Recorded output. The caller's array is NOT modified (a copy is used).
    error_func : callable
        ``error_func(target, phase)``, expected to reduce over the frequency
        axis to one error per row (e.g. ``rmse`` or ``maxnorm``).
    weighted : bool
        If True, the phase is the angle of the output-weighted circular mean
        of the modulation phase. Otherwise, it is the modulation phase at the
        time of maximum output.
    interpolate_bmf : bool
        If True, refine the best modulation frequency by quadratic
        interpolation of the normalised peak rate on 100 points spanning the
        ``dietz_fm`` range.

    Returns
    -------
    peak_phase : (N, n_fm) phase in [0, 2*pi)
    peak_fr : (N, n_fm) maximum output in the analysis window
    norm_peak_fr : (N, n_fm) peak_fr divided by each row's maximum
    mse : (N,) ``error_func(dietz_phase, peak_phase)``
    mse_norm : (N,) mse min-max rescaled to [0, 1]; all zeros when every
        entry of mse is equal (previously 0/0 -> NaN)
    bmf : (N,) best modulation frequency, in dietz_fm's numeric values
        (units stripped)
    moddepth : (N,) ``1 - min(norm_peak_fr)`` per row
    """
    n_fm = len(dietz_fm)
    # reshape usually returns a view; copy so that zeroing the start-up window
    # below does not silently mutate the caller's (possibly cached) array.
    out = reshape(out, (N, n_fm, len(t))).copy()
    fm = dietz_fm
    # Zero samples earlier than n/fm, i.e. 0.25 s rounded to a whole number
    # of modulation cycles for each frequency.
    n = array(around(0.25*second*fm), dtype=int)
    idx = (t[newaxis, newaxis, :]<(n/fm)[newaxis, :, newaxis])+zeros(out.shape, dtype=bool)
    out[idx] = 0
    if weighted:
        # Output-weighted circular mean of the modulation phase.
        phase = (2*pi*fm[newaxis, :, newaxis]*t[newaxis, newaxis, :]) % (2*pi)
        peak_phase = (angle(sum(out*exp(1j*phase), axis=2))+2*pi)%(2*pi)
    else:
        # Modulation phase at the time of maximum output.
        peak = t[argmax(out, axis=2)] # shape (N, n_fm)
        peak_phase = (peak*2*pi*fm[newaxis, :]) % (2*pi) # shape (N, n_fm)
    peak_fr = amax(out, axis=2) # shape (N, n_fm)
    norm_peak_fr = peak_fr/amax(peak_fr, axis=1)[:, newaxis]
    # error_func reduces over the frequency axis -> one value per row, shape (N,).
    # NOTE(review): phases are compared linearly, not circularly, so a model
    # phase just below 2*pi counts as far from a target near 0 — confirm intent.
    mse = error_func(dietz_phase[newaxis, :], peak_phase)
    mse_range = amax(mse)-amin(mse)
    mse_norm = mse-amin(mse)
    if mse_range > 0:
        mse_norm = mse_norm/mse_range
    # Plain float frequencies: avoids integer truncation and unit handling.
    fm_values = asarray(dietz_fm, dtype=float)
    bmf = fm_values[argmax(norm_peak_fr, axis=1)]
    moddepth = 1-amin(norm_peak_fr, axis=1)
    if interpolate_bmf:
        # Quadratic interpolation of every row at once (same per-row result as
        # looping), evaluated on a fine grid over the measured frequency range.
        fm_interp = linspace(amin(fm_values), amax(fm_values), 100)
        fr_interp = interp1d(fm_values, norm_peak_fr, kind='quadratic', axis=1)(fm_interp)
        bmf = fm_interp[argmax(fr_interp, axis=1)]
    return peak_phase, peak_fr, norm_peak_fr, mse, mse_norm, bmf, moddepth

Error functions comparing the target phases with the model phases across modulation frequencies

In [None]:
def rmse(x, y, axis=1):
    """Root-mean-square of ``x - y`` taken along ``axis`` (default: axis 1)."""
    squared_error = (x - y)**2
    return sqrt(mean(squared_error, axis=axis))

def maxnorm(x, y, axis=1):
    """Largest absolute entry of ``x - y`` along ``axis`` (max-norm of the error)."""
    abs_error = abs(x - y)
    return amax(abs_error, axis=axis)

# Display name -> error function. With the default axis=1, each function
# maps (target, model) arrays to one error value per row.
error_functions = dict([
    ('RMS error', rmse),
    ('Max error', maxnorm),
])

LaTeX display names for the model parameters

In [None]:
# Display labels (LaTeX math) for the model parameters, keyed by the
# parameter names used in simple_model's equations.
# NOTE(review): tauihc_ms is also a model parameter but has no entry here —
# confirm whether lookups for it are expected.
latex_parameter_names = {
    'taue_ms': r"$\tau_e$ (ms)",
    'taui_ms': r"$\tau_i$ (ms)",
    'taua_ms': r"$\tau_a$ (ms)",
    'alpha': r"$\alpha$",
    'beta': r"$\beta$",
    'gamma': r"$\gamma$",
    'level': r"$L$ (dB)",
}