# Maps of simple model

TODO:
* chunking simulation/analysis to reduce memory load
* sampling/smoothing on population/map plot

## Common code / data

In [1]:
%matplotlib notebook
from brian2 import *
from model_explorer_jupyter import *
import ipywidgets as ipw
from collections import OrderedDict
from scipy.interpolate import interp1d
from matplotlib import cm
from sklearn.manifold import TSNE, LocallyLinearEmbedding, Isomap, SpectralEmbedding, MDS
from sklearn.decomposition import PCA
import joblib
from scipy.ndimage.interpolation import zoom
from scipy.ndimage.filters import gaussian_filter

# Suppress Brian2 log messages whose name matches 'resolution_conflict' to keep
# the notebook output readable.
BrianLogger.suppress_name('resolution_conflict')

def normed(X, *args):
    """Return X divided by the largest absolute value found in X or any of ``args``.

    The extra arrays only contribute to the scale factor; only X is returned.
    """
    candidates = (X,) + args
    peak = max([amax(abs(arr)) for arr in candidates])
    return X / peak

# Progress widget (shown in the GUI toolbars) and the callback passed as
# `report=` to the Brian2 runs in simple_model.
progress_slider, update_progress = brian2_progress_reporter()

# On-disk cache in the current directory for expensive simulations (@mem.cache).
# NOTE(review): joblib only trims the cache to bytes_limit when reduce_size()
# is called — confirm for the installed joblib version.
mem = joblib.Memory(location='.', bytes_limit=10*1024**3, verbose=0) # 10 GB max cache

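Experimental data the model should reproduce: the extracted phase at each modulation frequency (given in degrees, converted to radians).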

In [2]:
# Modulation frequencies (Brian2 quantity, Hz) and the measured phase at each
# of them, written in degrees and converted to radians.
dietz_fm = array([4, 8, 16, 32, 64])*Hz
dietz_phase = array([37, 40, 62, 83, 115])*pi/180

## Definition of basic model

In [3]:
@mem.cache
def simple_model(N, params):
    """Simulate N independent units of the simple envelope model and record `out`.

    Parameters
    ----------
    N : int
        Number of neurons in the NeuronGroup (one per parameter set / fm pair).
    params : dict
        Values for the group's state variables, passed to ``G.set_states``.
        Presumably scalars or length-N arrays for fm, fc_Hz, tauihc_ms,
        taue_ms, taui_ms, taua_ms, alpha, beta, gamma, level — confirm
        against callers.

    Returns
    -------
    (t, out)
        ``M.t[:]`` and ``M.out[:]``: recording times and the recorded `out`
        variable of every neuron. The model runs for 0.5 s in total; only the
        second 0.25 s is recorded.

    Results are memoised on disk by ``mem.cache``.
    """
    # IHC low-pass filtering is bypassed when tauihc is below this threshold.
    # It is referenced by name inside the equations and the condition string
    # below, so it must remain a local with this name.
    min_tauihc = 0.1*ms
    eqs = '''
    carrier = clip(cos(2*pi*fc*t), 0, Inf) : 1
    A_raw = (carrier*gain*0.5*(1-cos(2*pi*fm*t)))**gamma : 1
    dA_filt/dt = (A_raw-A)/(int(tauihc<min_tauihc)*1*second+tauihc) : 1
    A = A_raw*int(tauihc<min_tauihc)+A_filt*int(tauihc>=min_tauihc) : 1
    dQ/dt = -k*Q*A+R*(1-Q) : 1
    AQ = A*Q : 1
    dAe/dt = (AQ-Ae)/taue : 1
    dAi/dt = (AQ-Ai)/taui : 1
    out = clip(Ae-beta*Ai, 0, Inf) : 1
    gain = 10**(level/20.) : 1
    R = (1-alpha)/taua : Hz
    k = alpha/taua : Hz
    fc = fc_Hz*Hz : Hz
    fc_Hz : 1
    fm : Hz
    tauihc = tauihc_ms*ms : second
    taue = taue_ms*ms : second
    taui = taui_ms*ms : second
    taua = taua_ms*ms : second
    tauihc_ms : 1
    taue_ms : 1
    taui_ms : 1
    taua_ms : 1
    alpha : 1
    beta : 1
    gamma : 1
    level : 1
    '''
    G = NeuronGroup(N, eqs, method='euler', dt=0.1*ms)
    G.set_states(params)
    # Snap sub-threshold IHC time constants to exactly 0 (filter disabled).
    G.tauihc_ms['tauihc_ms<min_tauihc/ms'] = 0
    # Q relaxes towards 1 when A=0 (see dQ/dt), so start it there.
    G.Q = 1
    M = StateMonitor(G, 'out', record=True)
    net = Network(G)
    # First 0.25 s without the monitor (presumably to skip onset transients),
    # then add the monitor and record the next 0.25 s.
    net.run(.25*second, report=update_progress, report_period=1*second)
    net.add(M)
    net.run(.25*second, report=update_progress, report_period=1*second)
    return M.t[:], M.out[:]

def extract_peak_phase(N, t, out, error_func, weighted, interpolate_bmf=False):
    """Extract response phase, MTF and fit error from simulated model output.

    Parameters
    ----------
    N : int
        Number of parameter sets. ``out`` must hold ``N*len(dietz_fm)``
        traces ordered parameter-set-major, modulation-frequency-minor.
    t : array
        Recording times as returned by ``simple_model``.
    out : array
        Recorded output, reshapeable to ``(N, len(dietz_fm), len(t))``.
        NOTE: samples before the masking time below are zeroed in place
        (through a reshaped view when possible); no copy is made to keep
        memory use down.
    error_func : callable
        ``error_func(data, model)`` reducing over the fm axis (axis=1).
    weighted : bool
        If True the phase is the angle of the output-weighted mean phase
        vector; otherwise it is the phase at the time of peak output.
    interpolate_bmf : bool
        If True, refine the best modulation frequency by quadratic
        interpolation of the MTF on 100 points spanning dietz_fm.

    Returns
    -------
    peak_phase : (N, n_fm) phase in radians, in [0, 2*pi)
    peak_fr : (N, n_fm) maximum output per modulation frequency
    norm_peak_fr : (N, n_fm) peak_fr normalised by its maximum over fm
    mse : (N,) error between data and model phase (reduced by error_func)
    mse_norm : (N,) mse rescaled to [0, 1]; all zeros when mse is constant
    bmf : (N,) best modulation frequency (plain float, Hz)
    moddepth : (N,) 1 - min(norm_peak_fr) over fm
    """
    out = reshape(out, (N, len(dietz_fm), len(t)))
    fm = dietz_fm
    # Zero samples earlier than round(0.25 s * fm) whole modulation cycles.
    # The mask only depends on (fm, t), so index those two axes directly
    # instead of materialising an (N, n_fm, n_t) boolean array.
    n = array(around(0.25*second*fm), dtype=int)
    mask = asarray(t[newaxis, :]<(n/fm)[:, newaxis], dtype=bool) # shape (n_fm, n_t)
    out[:, mask] = 0
    if weighted:
        phase = (2*pi*fm[newaxis, :, newaxis]*t[newaxis, newaxis, :]) % (2*pi)
        peak_phase = (angle(sum(out*exp(1j*phase), axis=2))+2*pi)%(2*pi)
    else:
        peak = t[argmax(out, axis=2)] # shape (N, n_fm)
        peak_phase = (peak*2*pi*fm[newaxis, :]) % (2*pi) # shape (N, n_fm)
    peak_fr = amax(out, axis=2) # shape (N, n_fm)
    norm_peak_fr = peak_fr/amax(peak_fr, axis=1)[:, newaxis]
    # Error between data and model phase, reduced over fm by error_func; shape (N,)
    mse = error_func(dietz_phase[newaxis, :], peak_phase)
    mse_min = amin(mse)
    mse_range = amax(mse)-mse_min
    if mse_range==0:
        # All errors identical (e.g. N == 1): avoid 0/0 = nan
        mse_norm = zeros(len(mse))
    else:
        mse_norm = (mse-mse_min)/mse_range
    fm_values = asarray(dietz_fm)
    bmf = fm_values[argmax(norm_peak_fr, axis=1)]
    moddepth = 1-amin(norm_peak_fr, axis=1)
    # interpolated bmf
    if interpolate_bmf:
        fm_interp = linspace(amin(fm_values), amax(fm_values), 100)
        for cx in range(N):
            fr_interp_func = interp1d(fm_values, norm_peak_fr[cx, :], kind='quadratic')
            bmf[cx] = fm_interp[argmax(fr_interp_func(fm_interp))]
    return peak_phase, peak_fr, norm_peak_fr, mse, mse_norm, bmf, moddepth
    

### Specifications of parameters

In [4]:
# Slider specification for every model parameter. Each row below is
# (name, description, min, max, step, value) and is turned into a dict.
parameter_specs = list(
    dict(zip(('name', 'description', 'min', 'max', 'step', 'value'), row))
    for row in (
        ('fc_Hz', r"Carrier frequency (0=env only) $f_c$ (Hz)",
         0, 2000, 100, 0),
        ('tauihc_ms', r"Inner hair cell time constant (<0.1=off) $\tau_{ihc}$ (ms)",
         0, 10, 0.1, 0),
        ('taue_ms', r"Excitatory filtering time constant $\tau_e$ (ms)",
         0.1, 10, 0.1, 0.1),
        ('taui_ms', r"Inhibitory filtering time constant $\tau_i$ (ms)",
         0.1, 10, 0.1, 0.5),
        ('taua_ms', r"Adaptation time constant $\tau_a$ (ms)",
         0.1, 10, 0.1, 5),
        ('alpha', r"Adaptation strength $\alpha$",
         0, 0.99, 0.01, 0.8),
        ('beta', r"Inhibition strength $\beta$",
         0, 2, 0.01, 1.0),
        ('gamma', r"Compression power $\gamma$",
         0.1, 1, 0.01, 1.0),
        ('level', r"Relative sound level $L$ (dB)",
         -90, 90, 5, 0),
    ))

### Definition of error functions

In [5]:
def rmse(x, y, axis=1):
    """Root-mean-square difference between x and y, reduced along `axis`."""
    squared_diff = (x - y)**2
    return sqrt(mean(squared_diff, axis=axis))

def maxnorm(x, y, axis=1):
    """Largest absolute difference between x and y, reduced along `axis`."""
    deviation = abs(x - y)
    return amax(deviation, axis=axis)

# Display name -> error function, used by the GUI dropdowns.
error_functions = dict([
    ('RMS error', rmse),
    ('Max error', maxnorm),
    ])

### Definition of dimensionality reduction methods

In [6]:
# Display name -> 2-component embedding estimator ('None' disables the step).
# The same estimator instances are reused; fit_transform refits them each time.
dimensionality_reduction_methods = {'None': None}
dimensionality_reduction_methods.update(
    (method_label, estimator_class(n_components=2))
    for method_label, estimator_class in (
        ('t-SNE', TSNE),
        ('PCA', PCA),
        ('Isomap', Isomap),
        ('Locally linear embedding', LocallyLinearEmbedding),
        ('Spectral embedding', SpectralEmbedding),
        ('Multidimensional scaling', MDS),
    ))

## Plot types

### 2D map

In [7]:
def plot_map2d_mse_mtf(selected_axes, **kwds):
    global curfig
    # Set up ranges of variables, and generate arguments to pass to model function
    error_func_name = kwds.pop('error_func')
    error_func = error_functions[error_func_name]
    interpolate_bmf = kwds.pop('interpolate_bmf')
    detail_settings = dict(Low=10, Medium=40, High=100)
    M = detail_settings[kwds.pop('detail')]
    weighted = kwds.pop('weighted')
    axis_ranges = dict((k, linspace(*(v+(M,)))) for k, v in kwds.items() if k in selected_axes)
    axis_ranges['fm'] = dietz_fm
    array_kwds = meshed_arguments(selected_axes+('fm',), kwds, axis_ranges)
    vx, vy = selected_axes
    shape = array_kwds[vx].shape
    N = array_kwds[vx].size
    array_kwds[vx].shape = N
    array_kwds[vy].shape = N
    array_kwds['fm'].shape = N
    # Run the model
    t, out = simple_model(N, array_kwds)
    (all_peak_phase, all_peak_fr, all_norm_peak_fr,
     mse, mse_norm, bmf, moddepth) = extract_peak_phase(M*M, t, out, error_func, weighted,
                                                        interpolate_bmf=interpolate_bmf)    
    # Analyse the data
    peak_phase = all_peak_phase.reshape((M, M, -1))
    norm_peak_fr = all_norm_peak_fr.reshape((M, M, -1))
    bmf.shape = moddepth.shape = mse.shape = mse_norm.shape = (M, M)
    # Properties of lowest MSE value
    idx_best_y, idx_best_x = unravel_index(argmin(mse), mse.shape)
    xbest = axis_ranges[vx][idx_best_x]
    ybest = axis_ranges[vy][idx_best_y]
    best_peak_phase = peak_phase[idx_best_y, idx_best_x, :]
    best_norm_peak_fr = norm_peak_fr[idx_best_y, idx_best_x, :]
    print 'Best: {vx} = {xbest}, {vy} = {ybest}'.format(vx=vx, vy=vy, xbest=xbest, ybest=ybest)
    # Plot the data
    extent = (kwds[vx]+kwds[vy])
    def labelit(titletext):
        plot([xbest], [ybest], '+w')
        title(titletext)
        xlabel(sliders[vx].description)
        ylabel(sliders[vy].description)
        cb = colorbar()
        cb.set_label(titletext, rotation=270, labelpad=20)

    curfig = figure(dpi=65, figsize=(14, 8))
    clf()
    gs = GridSpec(2, 6, height_ratios=[1, 1])

    subplot(gs[0, :2])
    mse_deg = mse*180/pi
    imshow(mse_deg, origin='lower left', aspect='auto',
           interpolation='nearest', vmin=0, extent=extent)
    labelit(error_func_name)
    cs = contour(mse_deg, origin='lower', aspect='auto',
                 levels=[15, 30, 45], colors='w',
                 extent=extent)
    clabel(cs, colors='w', inline=True, fmt='%d')
    

    subplot(gs[0, 2:4])
    imshow(bmf, origin='lower left', aspect='auto',
           interpolation='nearest',
           vmin=float(amin(dietz_fm)), vmax=float(amax(dietz_fm)),
           extent=extent)
    labelit('BMF')

    subplot(gs[0, 4:6])
    imshow(moddepth, origin='lower left', aspect='auto',
           interpolation='nearest', vmin=0, vmax=1,
           extent=extent)
    labelit('Modulation depth')

    subplot(gs[1, :3])
    plot(dietz_fm/Hz, all_peak_phase.T*180/pi, '-', color=(0.2, 0.7, 0.2, 0.2), label='Model (all)')
    plot(dietz_fm/Hz, best_peak_phase*180/pi, '-o', lw=2, label='Model (best)')
    plot(dietz_fm/Hz, dietz_phase*180/pi, '--r', label='Data')
    handles, labels = gca().get_legend_handles_labels()
    lab2hand = OrderedDict()
    for h, l in zip(handles, labels):
        lab2hand[l] = h
    legend(lab2hand.values(), lab2hand.keys(), loc='upper left')
    grid()
    ylim(0, 180)
    xlabel('Modulation frequency (Hz)')
    ylabel('Extracted phase (deg)')

    subplot(gs[1, 3:])
    plot(dietz_fm/Hz, all_norm_peak_fr.T, '-', color=(0.2, 0.7, 0.2, 0.2))
    plot(dietz_fm/Hz, best_norm_peak_fr, '-o')
    fm_interp = linspace(4, 64, 1000)
    fr_interp_func = interp1d(dietz_fm/Hz, best_norm_peak_fr, kind='quadratic')
    plot(fm_interp, fr_interp_func(fm_interp), ':k')

    ylim(0, 1)
    xlabel('Modulation frequency (Hz)')
    ylabel('Relative MTF')

    tight_layout()

### Population space

In [8]:
current_population_space_variables = {}

def plot_population_space(**kwds):
    # always use the same random seed for cacheing
    seed(34032483)
    # Get simple parameters
    detail_settings = dict(Low=100, Medium=1000, High=10000)
    N = detail_settings[kwds.pop('detail')]
    weighted = kwds.pop('weighted')
    error_func_name = kwds.pop('error_func')
    error_func = error_functions[error_func_name]
    dr_error_cutoff = kwds.pop('dr_error_cutoff')*pi/180
    dr_error_cutoff_plotting = kwds.pop('dr_error_cutoff_plotting')*pi/180
    if dr_error_cutoff_plotting>dr_error_cutoff:
        dr_error_cutoff_plotting = dr_error_cutoff
    dr_method_name = kwds.pop('dr_method')
    dr_method = dimensionality_reduction_methods[dr_method_name]
    dr_similarity = kwds.pop('dr_similarity')
    # Set up array keywords
    array_kwds = {}
    param_values = {}
    for k, (low, high) in kwds.items():
        v = rand(N)*(high-low)+low
        param_values[k] = v
        fm, v = meshgrid(dietz_fm, v) # fm and v have shape (N, len(dietz_fm))!
        fm.shape = fm.size
        v.shape = v.size
        array_kwds['fm'] = fm
        array_kwds[k] = v
    # Run the model
    t, out = simple_model(N*len(dietz_fm), array_kwds)
    (peak_phase, peak_fr, norm_peak_fr,
     mse, mse_norm, bmf, moddepth) = extract_peak_phase(N, t, out, error_func, weighted)
    # Properties of lowest MSE value
    idx_best = argmin(mse)
    best_peak_phase = peak_phase[idx_best, :]
    best_norm_peak_fr = norm_peak_fr[idx_best, :]
    bestvals = []
    for k in kwds.keys():
        v = param_values[k][idx_best]
        bestvals.append('%s=%.2f' % (k, v))
    print 'Best: ' + ', '.join(bestvals)
    # Plot the data
    if dr_method is not None:
        curfig = figure(dpi=65, figsize=(14, 10))
    else:
        curfig = figure(dpi=65, figsize=(14, 6))
    clf()
    if dr_method is not None:
        subplot(221)
    else:
        subplot(121)
    transp = clip(0.3*100./N, 0.01, 1)
    plot(dietz_fm/Hz, peak_phase.T*180/pi, '-', color=(0.4, 0.7, 0.4, transp), label='Model (all)')
    plot(dietz_fm/Hz, best_peak_phase*180/pi, '-ko', lw=2, label='Model (best)')
    plot(dietz_fm/Hz, dietz_phase*180/pi, '--r', label='Data')
    handles, labels = gca().get_legend_handles_labels()
    lab2hand = OrderedDict()
    for h, l in zip(handles, labels):
        lab2hand[l] = h
    legend(lab2hand.values(), lab2hand.keys(), loc='upper left')
    grid()
    ylim(0, 180)
    xlabel('Modulation frequency (Hz)')
    ylabel('Extracted phase (deg)')

    if dr_method is not None:
        subplot(222)
    else:
        subplot(122)
    lines = plot(dietz_fm/Hz, norm_peak_fr.T, '-')
    for i, line in enumerate(lines):
        line.set_color(cm.YlGnBu_r(mse_norm[i], alpha=transp))
    lines[argmin(mse)].set_alpha(1)
    lines[argmax(mse)].set_alpha(1)
    lines[argmin(mse)].set_label('Model (all, best MSE)')
    lines[argmax(mse)].set_label('Model (all, worst MSE)')
    plot(dietz_fm/Hz, best_norm_peak_fr, '-ko', lw=2)
    fm_interp = linspace(4, 64, 1000)
    fr_interp_func = interp1d(dietz_fm/Hz, best_norm_peak_fr, kind='quadratic')
    plot(fm_interp, fr_interp_func(fm_interp), ':k', lw=2)
    legend(loc='best')
    ylim(0, 1)
    xlabel('Modulation frequency (Hz)')
    ylabel('Relative MTF')
    
    if dr_method is not None:
        # Setup which indices we will keep
        keep_indices = mse<dr_error_cutoff
        # Setup the variable we will use for layout
        all_params = vstack(param_values.values()).T
        if dr_similarity=='Results':
            dr_args = peak_phase[keep_indices, :] # (paramset, fm)
        elif dr_similarity=='Parameters':
            dr_args = all_params[keep_indices, :] # (paramset, param)
        # Carry out the dimensionality reduction
        Xe = dr_method.fit_transform(dr_args)
        # Keep only the indices we'll plot
        Xe = Xe[mse[keep_indices]<dr_error_cutoff_plotting, :]
        # Calculate variables for colour/size
        idx_best = argmin(mse)
        distance_from_best = error_func(all_params[mse<dr_error_cutoff_plotting, :], all_params[idx_best, :][newaxis, :])
        distance_from_best = normed(distance_from_best)
        fit_to_data = mse_norm[mse<dr_error_cutoff_plotting]
        #mse_all_pairs = error_func(peak_phase[newaxis, keep_indices, :], peak_phase[keep_indices, newaxis, :], axis=2)

        subplot(223)
        scatter(Xe[:, 0], Xe[:, 1], c=fit_to_data, cmap=cm.viridis, s=(1+5*(1-distance_from_best))**2)
        xticks([])
        yticks([])
        title(('Dimensionality reduction: {dr_method_name} on {dr_similarity}\n'
               'Colour: fit to data, size: distance from best parameters').format(
                    dr_method_name=dr_method_name,
                    dr_similarity=dr_similarity))
        subplot(224)
        scatter(Xe[:, 0], Xe[:, 1], c=distance_from_best, cmap=cm.viridis, s=(1+5*(1-fit_to_data))**2)
        xticks([])
        yticks([])
        title(('Dimensionality reduction: {dr_method_name} on {dr_similarity}\n'
               'Colour: distance from best parameters, size: fit to data').format(
                    dr_method_name=dr_method_name,
                    dr_similarity=dr_similarity))
        current_population_space_variables.update(dict(
            dr_method=dr_method, dr_error_cutoff=dr_error_cutoff,
            dr_error_cutoff_plotting=dr_error_cutoff_plotting,
            keep_indices=keep_indices, 
            fit_to_data=fit_to_data, distance_from_best=distance_from_best, Xe=Xe,
            ))

    tight_layout()
    
    # Store current variables in global dictionary so we can re-use in notebook
    current_population_space_variables.update(dict(
        N=N, param_values=param_values,
        error_func=error_func,
        peak_phase=peak_phase, peak_fr=peak_fr,
        norm_peak_fr=norm_peak_fr,
        mse=mse, mse_norm=mse_norm,
        bmf=bmf, moddepth=moddepth,
        ))
    
if 0: # set to 1 when debugging plotting
    plot_population_space(detail='Low', show_tsne=False,
                          taui_ms=(0.1, 5), taue_ms=(0.1, 5), taua_ms=(0.1, 10),
                          level=(0, 0), alpha=(0, 0.99), beta=(0, 1),
                          gamma=(1, 1))

### Combined population / 2D map

In [9]:
# Ways of summarising the errors of all random parameter sets at one grid point.
population_summary_methods = {
    'Mean': mean,
    'Best': amin,
    }

def plot_population_map(selected_axes, **kwds):
    """Plot a 2D error map where each grid point summarises a random population.

    For each point of an M x M grid over the two ``selected_axes``
    parameters, ``num_params`` parameter sets are drawn, with every other
    parameter sampled uniformly from its ``(low, high)`` range. Their errors
    are summarised with ``population_summary_methods[pop_summary]``.

    Keyword arguments: 'pop_summary', 'error_func', 'detail' (Low/Medium/High
    -> grid size M, num_params, blur width as a fraction of M), 'weighted',
    'smoothing' (Gaussian blur then upsampling to ~100 pixels per axis), plus
    a ``(low, high)`` range for every model parameter.

    Stores the figure in the global ``curfig``.
    """
    global curfig
    # always use the same random seed for cacheing
    seed(34032483)
    # Set up ranges of variables, and generate arguments to pass to model function
    pop_summary_name = kwds.pop('pop_summary')
    pop_summary = population_summary_methods[pop_summary_name]
    error_func_name = kwds.pop('error_func')
    error_func = error_functions[error_func_name]
    detail_settings = dict(Low=(10, 20, 0.1),
                           Medium=(20, 100, 0.05),
                           High=(30, 500, 0.025))
    M, num_params, blur_width = detail_settings[kwds.pop('detail')]
    weighted = kwds.pop('weighted')
    smoothing = kwds.pop('smoothing')
    axis_ranges = dict((k, linspace(*(v+(M,)))) for k, v in kwds.items() if k in selected_axes)
    axis_ranges['fm'] = dietz_fm
    # Dummy axis so meshed_arguments repeats each grid point num_params times
    axis_ranges['temp'] = zeros(num_params)
    array_kwds = meshed_arguments(selected_axes+('temp', 'fm'), kwds, axis_ranges)
    del array_kwds['temp']
    vx, vy = selected_axes
    # array_kwds[vx] has shape (M, M, num_params, len(dietz_fm))
    N = array_kwds[vx].size
    for k, (low, high) in kwds.items():
        if k not in selected_axes:
            # Non-axis parameters are sampled independently for every neuron
            array_kwds[k] = rand(N)*(high-low)+low
        array_kwds[k].shape = N
    array_kwds['fm'].shape = N
    # Run the model
    t, out = simple_model(N, array_kwds)
    (all_peak_phase, all_peak_fr, all_norm_peak_fr,
     mse, mse_norm, bmf, moddepth) = extract_peak_phase(M*M*num_params, t, out,
                                                        error_func, weighted)
    # Analyse the data: summarise the population at each grid point
    peak_phase = all_peak_phase.reshape((M, M, num_params, -1))
    norm_peak_fr = all_norm_peak_fr.reshape((M, M, num_params, -1))
    bmf.shape = moddepth.shape = mse.shape = mse_norm.shape = (M, M, num_params)
    mse = pop_summary(mse, axis=2)
    # Plot the data
    if smoothing:
        mse = gaussian_filter(mse, blur_width*M, mode='nearest')
        mse = zoom(mse, 100./M, order=1)
    extent = (kwds[vx]+kwds[vy])
    curfig = figure()
    mse_deg = mse*180/pi
    imshow(mse_deg, origin='lower', aspect='auto',
           interpolation='nearest', vmin=0, extent=extent)
    xlabel(sliders[vx].description)
    ylabel(sliders[vy].description)
    cb = colorbar()
    cb.set_label(error_func_name, rotation=270, labelpad=20)
    cs = contour(mse_deg, origin='lower',
                 levels=[15, 30, 45], colors='w',
                 extent=extent)
    clabel(cs, colors='w', inline=True, fmt='%d')
    tight_layout()

## GUI

In [10]:
# One single-value slider and one range slider per entry of parameter_specs,
# keyed by parameter name. OrderedDict preserves the parameter_specs order
# when iterated. Range sliders default to the full (min, max) range.
sliders = OrderedDict([
    (spec['name'],
     ipw.FloatSlider(description=spec['description'], min=spec['min'], max=spec['max'],
                     step=spec['step'], value=spec['value'])) for spec in parameter_specs])
range_sliders = OrderedDict([
    (spec['name'],
     ipw.FloatRangeSlider(description=spec['description'], min=spec['min'], max=spec['max'],
                     step=spec['step'], value=(spec['min'], spec['max']))) for spec in parameter_specs])

# Option widgets shared by all three explorer tabs.
detail_slider = ipw.Dropdown(description="Detail",
                             options=["Low", "Medium", "High"],
                             value='Low')

error_func_dropdown = ipw.Dropdown(description="Error function", options=error_functions.keys())

weighted_widget = ipw.Checkbox(description="Use weighted mean phase instead of peak", value=False)

def full_width_widget(widget):
    """Stretch ``widget`` to 95% width with a 30%-wide description; return it for chaining."""
    layout = widget.layout
    layout.width = '95%'
    widget.style = {'description_width': '30%'}
    return widget

# Give every shared parameter/option widget the same full-width layout.
for slider in (list(sliders.values()) + list(range_sliders.values())
               + [detail_slider, error_func_dropdown, weighted_widget]):
    full_width_widget(slider)

def savecurfig(fname):
    """Write the figure held in the global ``curfig`` to ``fname``."""
    figure_to_save = curfig
    figure_to_save.savefig(fname)

# Save-figure widget wired to savecurfig (shown in every tab's toolbar).
widget_savefig = save_fig_widget(savecurfig)

#########################################################################
# Model 1: MSE/MTF 2d maps
# Selector for the two parameters that form the map axes (labels = slider descriptions).
vars_mse_mtf = OrderedDict((k, v.description) for k, v in sliders.items())
vs2d_mse_mtf = VariableSelector(vars_mse_mtf, ['Horizontal axis', 'Vertical axis'], title=None,
                                initial={'Horizontal axis': 'alpha',
                                         'Vertical axis': 'taua_ms'})
options2d_mse_mtf = {'var': vs2d_mse_mtf.widgets_vertical}

# Widgets currently used by the 2D map tab (filled by map2d); passed to the
# load/save parameters widget.
current_map2d_widgets = {}

def map2d(runmodel, vs2d):
    """Return a factory ``f()`` that builds the interactive 2D-map widget.

    ``runmodel`` is not used: the widget always calls plot_map2d_mse_mtf.
    ``vs2d`` is the VariableSelector choosing the horizontal/vertical axes.
    """
    def f():
        # Presumably range sliders for the selected axes and value sliders for
        # the rest — see VariableSelector.merge_selected.
        params = vs2d.merge_selected(range_sliders, sliders)
        # Record the widgets in use so their values can be saved/loaded.
        current_map2d_widgets.clear()
        current_map2d_widgets.update(params)
        params['detail'] = detail_slider
        params['interpolate_bmf'] = full_width_widget(ipw.Checkbox(description="Interpolate BMF",
                                                                   value=True))
        params['weighted'] = weighted_widget
        params['error_func'] = error_func_dropdown
        def plotter(**kwds):
            # Read the axis selection at run time, not at widget creation time
            vx = vs2d.selection['Horizontal axis']
            vy = vs2d.selection['Vertical axis']
            return plot_map2d_mse_mtf((vx, vy), **kwds)
        i = ipw.interactive(plotter, dict(manual=True, manual_name="Run simulation"), **params)
        return no_continuous_update(i)
    return f

#########################################################################
# Model 2: population space    
    
def population_space():
    """Build the grouped interactive widget for the Population tab."""
    params = range_sliders.copy()
    params['weighted'] = weighted_widget
    params['detail'] = detail_slider
    # Unnamed group: parameter ranges plus general options (copied before the
    # dimensionality reduction widgets are added below).
    param_groups = OrderedDict([('', params.copy())])
    # Dimensionality reduction parameters
    params['dr_method'] = ipw.Dropdown(description="Method",
                                       options=dimensionality_reduction_methods.keys(),
                                       value='None')
    params['dr_similarity'] = ipw.Dropdown(description="Similarity using",
                                           options=["Results", "Parameters"])
    params['error_func'] = error_func_dropdown
    params['dr_error_cutoff'] = ipw.FloatSlider(description="Error cutoff for dimensionality reduction (deg)",
                                                min=0, max=180, step=1, value=180)
    params['dr_error_cutoff_plotting'] = ipw.FloatSlider(description="Error cutoff for plotting (deg)",
                                                         min=0, max=180, step=1, value=180)
    dr_names = ['dr_method', 'dr_similarity', 'error_func',
                'dr_error_cutoff', 'dr_error_cutoff_plotting']
    param_groups['Dimensionality reduction'] = OrderedDict([(name, params[name]) for name in dr_names])
    for w in params.values():
        full_width_widget(w)
    # setup GUI
    i = grouped_interactive(plot_population_space, param_groups, manual_name="Run simulation")
    return i

#########################################################################
# Model 3: Combined population / 2D map
# Selector for the two parameters that form the population-map axes.
vars_pop_map = OrderedDict((k, v.description) for k, v in sliders.items())
vs2d_pop_map = VariableSelector(vars_pop_map, ['Horizontal axis', 'Vertical axis'], title=None,
                                initial={'Horizontal axis': 'alpha',
                                         'Vertical axis': 'beta'})
options2d_pop_map = {'var': vs2d_pop_map.widgets_vertical}

# Widgets currently used by the population/map tab (filled by population_map);
# passed to the load/save parameters widget.
current_pop_map_widgets = {}

def population_map():
    """Build the interactive widget for the Population/map tab."""
    # All parameters are given as ranges: the selected axes are gridded,
    # the others sampled randomly (see plot_population_map).
    params = range_sliders.copy()
    # Record the range widgets so their values can be saved/loaded.
    current_pop_map_widgets.clear()
    current_pop_map_widgets.update(params)
    params['pop_summary'] = full_width_widget(
        ipw.Dropdown(description="Population summary method",
                     options=population_summary_methods.keys(),
                     value="Best"))
    params['detail'] = detail_slider
    params['weighted'] = weighted_widget
    params['smoothing'] = full_width_widget(
        ipw.Checkbox(description="Image smoothing", value=True))
    params['error_func'] = error_func_dropdown
    def plotter(**kwds):
        # Read the axis selection at run time, not at widget creation time
        vx = vs2d_pop_map.selection['Horizontal axis']
        vy = vs2d_pop_map.selection['Vertical axis']
        return plot_population_map((vx, vy), **kwds)
    i = ipw.interactive(plotter, dict(manual=True, manual_name="Run simulation"), **params)
    return no_continuous_update(i)

#########################################################################
# Construct and show GUI

# Each entry: (tab name, widget factory, extra option widgets, extra toolbar
# widgets) — presumably the layout model_explorer expects; confirm there.
models = [('2d map', map2d(simple_model, vs2d_mse_mtf), options2d_mse_mtf,
               [load_save_parameters_widget(current_map2d_widgets, 'saved_params_simple_map2d'),
                widget_savefig, progress_slider]),
          ('Population', population_space, {},
               [load_save_parameters_widget(range_sliders, 'saved_params_simple_population'),
                widget_savefig, progress_slider]),
          ('Population/map', population_map, options2d_pop_map,
               [load_save_parameters_widget(current_pop_map_widgets, 'saved_params_simple_popmap'),
                widget_savefig, progress_slider]),
         ]

# Create model explorer, preselect the 'Population/map' model and jump
# immediately to the second tab (index 1), presumably the results page
modex = model_explorer(models)
modex.widget_model_type.value = 'Population/map'
modex.tabs.selected_index = 1
display(modex)

VkJveChjaGlsZHJlbj0oVGFiKGNoaWxkcmVuPShWQm94KGNoaWxkcmVuPShSYWRpb0J1dHRvbnMoZGVzY3JpcHRpb249dSdNb2RlbCB0eXBlJywgaW5kZXg9Miwgb3B0aW9ucz0oJzJkIG1hcCfigKY=


## Investigations

### Dimensionality reduction

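The cells below only run after a simulation in the Population tab of the GUI above with a dimensionality reduction method selected. They read the results stored in `current_population_space_variables`. The component and covariance cells additionally require the method to be PCA with similarity set to "Parameters".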

In [11]:
globals().update(current_population_space_variables)
figure()
subplot(221)
scatter(Xe[:, 0], Xe[:, 1], c=fit_to_data, cmap=cm.viridis, s=(1+5*(1-distance_from_best))**2)
subplot(222)
scatter(Xe[:, 0], Xe[:, 1], c=distance_from_best, cmap=cm.viridis, s=(1+5*(1-fit_to_data))**2)
subplot(223)
hist(mse[mse<dr_error_cutoff_plotting]*180/pi)

<IPython.core.display.Javascript object>

NameError: name 'Xe' is not defined

In [None]:
# Collect the parameters that actually varied in the last Population run, and
# the column index of each parameter in the (param_values order) matrix.
varying_param_values = {}
param_value_index = {}
for j, (k, v) in enumerate(param_values.items()):
    param_value_index[k] = j
    if amin(v)!=amax(v):
        varying_param_values[k] = v
# NOTE(review): requires dr_method to be a fitted PCA (explained_variance_ratio_,
# components_); component columns correspond to parameters only when the
# reduction was run with similarity 'Parameters'.
for i in range(2):
    s = 'Component %d (explains %d%%):\n\t' % (i, 100*dr_method.explained_variance_ratio_[i])
    cpts = []
    for k in varying_param_values.keys():
        c = dr_method.components_[i, param_value_index[k]]
        if c:
            cpts.append('%s = %.3f' % (k, c))
    s += ',\n\t'.join(cpts)
    print s
print 'Total explained variance %d%%' % (100*sum(dr_method.explained_variance_ratio_))

In [None]:
# Histogram of each varying parameter over the parameter sets that pass the
# plotting error cutoff. Uses 4 columns and at least 2 rows, adding rows when
# more than 8 parameters vary (parameter_specs defines 9).
n_hist_cols = 4
n_hist_rows = int(ceil(len(varying_param_values)/float(n_hist_cols)))
if n_hist_rows<2:
    n_hist_rows = 2
figure(figsize=(10, 2*n_hist_rows))
for j, k in enumerate(varying_param_values.keys()):
    subplot(n_hist_rows, n_hist_cols, j+1)
    title(k)
    hist(param_values[k][mse<dr_error_cutoff_plotting])
tight_layout()

In [None]:
# Approximate correlation matrix between the varying parameters, from the
# fitted dimensionality-reduction model.
# NOTE(review): requires dr_method.get_covariance() (e.g. PCA) and is only
# meaningful when the reduction was run on 'Parameters'. The normalising std
# is taken over all samples, not only the keep_indices subset used for
# fitting — confirm this is intended.
# NOTE(review): origin='lower left' is rejected by newer matplotlib ('lower').
figure()
I = array([param_value_index[k] for k in varying_param_values.keys()])
cov = dr_method.get_covariance()[I[:, newaxis], I[newaxis, :]]
s = array([std(v) for v in varying_param_values.values()])
cov = cov/(s[:, newaxis]*s[newaxis, :])
# Blank the diagonal so the colour scale only reflects cross-parameter terms
cov[arange(cov.shape[0]), arange(cov.shape[0])] = nan
imshow((cov), origin='lower left', aspect='auto', interpolation='nearest')
xticks(range(len(varying_param_values)), varying_param_values.keys())
yticks(range(len(varying_param_values)), varying_param_values.keys())
colorbar()

### Degeneracy

Find a path through well-fitting parameter sets that connects the best inhibition-only solution ($\alpha = 0$) to the best adaptation-only solution ($\beta = 0$), and plot it.

In [None]:
# Number of random parameter sets for the unrestricted search (each restricted
# search below uses N/2), and the error (radians) below which a parameter set
# counts as a good fit.
# N = 1000; error_cutoff = 35*pi/180
N = 5000; error_cutoff = 25*pi/180
error_func = error_functions['Max error']
weighted = False

def uniform_random(N, min, max):
    """Draw N samples uniformly from the half-open interval [min, max)."""
    span = max - min
    return min + span*rand(N)

# Seed the global RNG used by uniform_random (via rand). Because runsim is
# cached, draws only happen the first time each runsim call is computed.
seed(3489234)

# Parameter values shared by every simulated parameter set (not sampled)
fixed_params = dict(
    fc_Hz=0,
    tauihc_ms=0,
    gamma=0.9,
    )

@mem.cache
def runsim(N, **kwds):
    """Simulate N random parameter sets and extract their phase responses.

    Parameters are drawn uniformly from the ranges below; alpha and beta come
    from two overlapping rectangles (about half of the points each). Any
    keyword argument (e.g. ``beta=0``) fixes that parameter to the given value
    instead; a sampled parameter fixed this way is dropped from the returned
    ``array_params``.

    Returns
    -------
    (array_params, result of extract_peak_phase(N, ...))

    NOTE(review): only (N, kwds) and the function code form the cache key.
    The draws depend on the global RNG state, and the module-level
    error_func, weighted and fixed_params are not part of the key.
    """
    # Split so that both halves together always have exactly N entries (odd N too)
    n_first = N//2
    n_second = N-n_first
    # Generate some random parameters
    array_params = dict(
        taue_ms=uniform_random(N, 0.1, 1),
        taui_ms=uniform_random(N, 3, 7),
        taua_ms=uniform_random(N, 0.5, 5),
        # alpha and beta are two overlapping rectangles
        alpha=hstack((uniform_random(n_first, 0, 0.99), uniform_random(n_second, 0.8, 0.99))),
        beta=hstack((uniform_random(n_first, 0.6, 1.1), uniform_random(n_second, 0, 1.2))),
        level=uniform_random(N, 0, 20),
        )

    # Fixed values override sampled ones (and may also override fixed_params)
    sim_fixed_params = fixed_params.copy()
    for k, v in kwds.items():
        array_params.pop(k, None)
        sim_fixed_params[k] = v

    # Setup vectorisation: one neuron per (parameter set, modulation frequency)
    array_kwds = sim_fixed_params.copy()
    for k, v in array_params.items():
        fm, v = meshgrid(dietz_fm, v) # fm and v have shape (N, len(dietz_fm))!
        fm.shape = fm.size
        v.shape = v.size
        array_kwds['fm'] = fm
        array_kwds[k] = v

    # Run the model
    t, out = simple_model(N*len(dietz_fm), array_kwds)
    return array_params, extract_peak_phase(N, t, out, error_func, weighted)

# Step 1, find best adaptation only parameters (beta fixed to 0):
array_params_adapt, (peak_phase_adapt, peak_fr, norm_peak_fr, mse, mse_norm, bmf, moddepth) = runsim(N//2, beta=0)
idx_adapt = argmin(mse)
print 'Best adaptation only error %.2f, number of points found %d' % (mse[idx_adapt]*180/pi, sum(mse<error_cutoff))
for k, v in array_params_adapt.items():
    print '\t', k, v[idx_adapt]
# Step 2, find best inhibition only parameters (alpha fixed to 0):
array_params_inh, (peak_phase_inh, peak_fr, norm_peak_fr, mse, mse_norm, bmf, moddepth) = runsim(N//2, alpha=0)
idx_inh = argmin(mse)
print 'Best inhibition only error %.2f, number of points found %d' % (mse[idx_inh]*180/pi, sum(mse<error_cutoff))
for k, v in array_params_inh.items():
    print '\t', k, v[idx_inh]
# Step 3, find best parameters everywhere:
array_params, (peak_phase, peak_fr, norm_peak_fr, mse, mse_norm, bmf, moddepth) = runsim(N)
i = argmin(mse)
print 'Best error %.2f, number of points found %d' % (mse[i]*180/pi, sum(mse<error_cutoff))
for k, v in array_params.items():
    print '\t', k, v[i]
# Step 4, generate a graph of neighbourhoods.
# Row 0 = best inhibition-only point, row 1 = best adaptation-only point,
# remaining rows = every good point of the unrestricted search. A parameter
# that was fixed in a restricted search (alpha=0 / beta=0) stays 0.
var_to_idx = dict((k, i) for i, k in enumerate(array_params.keys()))
I, = (mse<error_cutoff).nonzero()
M = len(I)+2 # number of points
X = zeros((M, len(array_params)))
X[0, [var_to_idx[k] for k in array_params_inh.keys()]] = [v[idx_inh] for v in array_params_inh.values()]
X[1, [var_to_idx[k] for k in array_params_adapt.keys()]] = [v[idx_adapt] for v in array_params_adapt.values()]
for k, v in array_params.items():
    X[2:, var_to_idx[k]] = v[I]
# Normalise each parameter to [0, 1] over the point set, then pairwise distances
Xmax = amax(X, axis=0)[newaxis, :]
Xmin = amin(X, axis=0)[newaxis, :]
Y = (X-Xmin)/(Xmax-Xmin)
D = error_func(Y[newaxis, :, :], Y[:, newaxis, :], axis=2)
# Step 5, find shortest paths from adaptation only to inhibition only
# We want to take the smallest steps possible still allowing for a path to exist
from scipy.sparse.csgraph import dijkstra
def find_sp(eps):
    # Unweighted graph: points are connected when their distance is <= eps
    delta, pred = dijkstra(D<=eps, indices=[0, 1], return_predecessors=True)
    return pred
def has_sp(eps):
    # Node 1 has a predecessor on a path from node 0 iff the nodes are connected
    pred = find_sp(eps)
    return pred[0, 1]>=0
# use bisection until we have accuracy 0.01 (eps=1 always connects the
# normalised points when error_func is the max error)
eps_acc = 0.01
low, high = 0.0, 1.0
while high-low>eps_acc:
    mid = (low+high)/2.
    if has_sp(mid):
        high = mid
    else:
        low = mid
eps = high
# extract the corresponding path (walk predecessors back from node 1 to node 0)
pred = find_sp(eps)
path = [1]
while path[-1]!=0:
    path.append(pred[0, path[-1]])
print path
# plot the path: left = phase curves along the path, right = path in (alpha, beta)
figure(dpi=75, figsize=(10, 4))
px = py = None
for i, idx in enumerate(path):
    c = cm.viridis(1.0*i/(len(path)-1))
    if idx==0:
        pp = peak_phase_inh[idx_inh, :]
        lab = 'Inhibition only'
        x = 0
        y = array_params_inh['beta'][idx_inh]
    elif idx==1:
        pp = peak_phase_adapt[idx_adapt, :]
        lab = 'Adaptation only'
        x = array_params_adapt['alpha'][idx_adapt]
        y = 0
    else:
        j = I[idx-2]
        pp = peak_phase[j, :]
        lab = None
        x = array_params['alpha'][j]
        y = array_params['beta'][j]
    subplot(122)
    if px is not None:
        plot([px, x], [py, y], '-', c=c, zorder=-1)
    px = x
    py = y
    plot([x], [y], 'o', c=c)
    subplot(121)
    plot(dietz_fm/Hz, pp*180/pi, c=c, label=lab)
subplot(121)
xlabel(r'Modulation frequency $f_m$ (Hz)')
ylabel('Extracted phase (deg)')
plot(dietz_fm/Hz, dietz_phase*180/pi, '--r', lw=2, label='Data')
grid()
legend(loc='best')
ylim(0, 180)
subplot(122)
xlabel(r'Adaptation strength $\alpha$')
ylabel(r'Inhibition strength $\beta$')
tight_layout()