# Maps of simple model

TODO:

* Loading parameters changes each parameter one by one. We could perhaps use `hold_trait_notifications`, but it is not entirely simple.
* Visualise the space of good parameter values, using Jonny's method or t-SNE.
* Show only the good parameter sets in the MTF plot?

In [None]:
from brian2 import *
from model_explorer_jupyter import *
import ipywidgets as ipw
import joblib
from collections import OrderedDict
from scipy.interpolate import interp1d

# Silence Brian2 log messages whose name matches 'resolution_conflict'.
BrianLogger.suppress_name('resolution_conflict')

def normed(X, *args):
    """Return X scaled by the largest absolute value seen in X or any extra array.

    The arrays in ``args`` only contribute to the scale factor; just the
    rescaled X is returned. If every value is zero this divides by zero.
    """
    candidates = [X]
    candidates.extend(args)
    peak = max([amax(abs(c)) for c in candidates])
    return X / peak

# Progress widget, plus the callback passed as `report=` to Brian2 runs in simple_model.
progress_slider, update_progress = brian2_progress_reporter()

# On-disk joblib cache in the directory "joblib". It is currently unused, because the
# `@mem.cache` decorator on simple_model is commented out.
mem = joblib.Memory(location="joblib", verbose=0)

# Reference data (plotted as 'Data'): modulation frequencies and the extracted
# phase at each frequency, converted from degrees to radians.
dietz_fm = array([4, 8, 16, 32, 64])*Hz
dietz_phase = array([37, 40, 62, 83, 115])*pi/180

In [None]:
#mem.cache
def simple_model(N, params):
    """Simulate N independent units of the simple adaptation model for 1 second.

    Each unit receives a raised-cosine modulated input of frequency ``fm``. The
    input is scaled by ``gain`` (from ``level`` in dB) and compressed by
    ``gamma``. It is multiplied by an adaptation variable ``Q``, which is
    depleted at rate ``k*A`` and recovers at rate ``R``. The output is the
    excitatory-filtered signal minus ``beta`` times the inhibitory-filtered
    signal. Integration uses forward Euler with a 0.1 ms step.

    Parameters
    ----------
    N : int
        Number of units to simulate.
    params : dict
        Per-unit state values handed to ``NeuronGroup.set_states`` (e.g. ``fm``,
        ``taue_ms``, ``alpha``, ``level``).

    Returns
    -------
    (t, out)
        Recorded times and the recorded ``out`` variable for every unit,
        as returned by the StateMonitor.
    """
    equations = '''
    A = (gain*0.5*(1-cos(2*pi*fm*t)))**gamma : 1
    dQ/dt = -k*Q*A+R*(1-Q) : 1
    AQ = A*Q : 1
    dAe/dt = (AQ-Ae)/taue : 1
    dAi/dt = (AQ-Ai)/taui : 1
    out = Ae-beta*Ai : 1
    gain = 10**(level/20.) : 1
    R = (1-alpha)/taua : Hz
    k = alpha/taua : Hz
    fm : Hz
    taue = taue_ms*ms : second
    taui = taui_ms*ms : second
    taua = taua_ms*ms : second
    taue_ms : 1
    taui_ms : 1
    taua_ms : 1
    alpha : 1
    beta : 1
    gamma : 1
    level : 1
    '''
    group = NeuronGroup(N, equations, method='euler', dt=0.1*ms)
    group.set_states(params)
    # Adaptation starts fully recovered.
    group.Q = 1
    monitor = StateMonitor(group, 'out', record=True)
    net = Network(group, monitor)
    net.run(1*second, report=update_progress, report_period=1*second)
    return monitor.t[:], monitor.out[:]

In [None]:
# Specification of every model parameter exposed in the GUI: the state-variable
# name, the slider label, and the slider bounds, step and default value.
parameter_specs = [
    {'name': 'taue_ms',
     'description': r"Excitatory filtering time constant $\tau_e$ (ms)",
     'min': 0.1, 'max': 10, 'step': 0.1, 'value': 0.1},
    {'name': 'taui_ms',
     'description': r"Inhibitory filtering time constant $\tau_i$ (ms)",
     'min': 0.1, 'max': 10, 'step': 0.1, 'value': 0.5},
    {'name': 'taua_ms',
     'description': r"Adaptation time constant $\tau_a$ (ms)",
     'min': 0.1, 'max': 10, 'step': 0.1, 'value': 5},
    {'name': 'alpha',
     'description': r"Adaptation strength $\alpha$",
     'min': 0, 'max': 0.99, 'step': 0.01, 'value': 0.8},
    {'name': 'beta',
     'description': r"Inhibition strength $\beta$",
     'min': 0, 'max': 2, 'step': 0.01, 'value': 1.0},
    {'name': 'gamma',
     'description': r"Compression power $\gamma$",
     'min': 0.1, 'max': 1, 'step': 0.01, 'value': 1.0},
    {'name': 'level',
     'description': r"Relative sound level $L$ (dB)",
     'min': -90, 'max': 90, 'step': 5, 'value': 0},
]

In [None]:
# One single-value slider and one range slider per parameter. Both are keyed by
# parameter name and kept in the same order as parameter_specs.
sliders = OrderedDict([
    (spec['name'],
     ipw.FloatSlider(description=spec['description'], min=spec['min'], max=spec['max'],
                     step=spec['step'], value=spec['value'])) for spec in parameter_specs])
# Range sliders start out spanning the full allowed interval of each parameter.
range_sliders = OrderedDict([
    (spec['name'],
     ipw.FloatRangeSlider(description=spec['description'], min=spec['min'], max=spec['max'],
                          step=spec['step'], value=(spec['min'], spec['max']))) for spec in parameter_specs])

# Level of detail: the plotters map it to the grid size (2d map) or the
# number of sampled parameter sets (population).
detail_slider = ipw.Dropdown(description="Detail",
                             options=["Low", "Medium", "High"],
                             value='Low')

# list(...) is needed so this works whether .values() returns a list (Python 2)
# or a view object that does not support `+` (Python 3).
for slider in list(sliders.values()) + list(range_sliders.values()):
    slider.layout.width = '95%'
    slider.style = {'description_width': '30%'}

def savecurfig(fname):
    """Save the figure held in the module-level global ``curfig`` to *fname*."""
    figure_to_save = curfig
    figure_to_save.savefig(fname)
# Save-figure widget that writes out whatever figure savecurfig finds.
widget_savefig = save_fig_widget(savecurfig)

#########################################################################
# Model 1: MSE/MTF 2d maps
# Candidate axis variables: each parameter name mapped to its slider label.
vars_mse_mtf = OrderedDict((pname, pwidget.description) for pname, pwidget in sliders.items())
vs2d_mse_mtf = VariableSelector(
    vars_mse_mtf,
    ['Horizontal axis', 'Vertical axis'],
    title=None,
    initial={'Horizontal axis': 'alpha',
             'Vertical axis': 'taua_ms'},
)
options2d_mse_mtf = dict(var=vs2d_mse_mtf.widgets_vertical)

def plot2d(modelfunc, vs2d):
    """Build the plotting callback for the 'MSE/MTF 2d map' model.

    Parameters
    ----------
    modelfunc : callable
        Called as ``modelfunc(N, params)``. It must return ``(t, out)`` like
        ``simple_model``: a time array and one output trace per unit.
    vs2d : VariableSelector
        Supplies the two selected axis variables via ``vs2d.selected`` and
        ``vs2d.selection['Horizontal axis' / 'Vertical axis']``.

    Returns
    -------
    plotter : callable
        ``plotter(**kwds)`` for use with ``ipw.interactive``. ``kwds`` contains:

        * a ``(low, high)`` tuple for each selected axis variable;
        * presumably fixed values for the remaining parameters;
        * the ``'detail'`` and ``'interpolate_bmf'`` controls.

        The plotter runs the model on an M x M grid for every frequency in
        ``dietz_fm``. It then draws the MSE, BMF and modulation-depth maps and
        the phase and MTF curves, and prints the best grid point. The figure
        is stored in the module-level ``curfig``.
    """
    def plotter(**kwds):
        global curfig
        # Pop the non-parameter controls; everything left in kwds is a model parameter.
        interpolate_bmf = kwds.pop('interpolate_bmf')
        detail_settings = dict(Low=10, Medium=40, High=100)
        M = detail_settings[kwds.pop('detail')]  # grid points per axis
        # Sample each selected axis variable at M evenly spaced points in its range.
        axis_ranges = dict((k, linspace(*(v+(M,)))) for k, v in kwds.items() if k in vs2d.selected)
        axis_ranges['fm'] = dietz_fm
        # meshed_arguments is defined elsewhere. It presumably returns one array per
        # parameter on a common (axis, axis, fm) grid -- TODO confirm.
        array_kwds = meshed_arguments(vs2d.selected+('fm',), kwds, axis_ranges)
        vx = vs2d.selection['Horizontal axis']
        vy = vs2d.selection['Vertical axis']
        shape = array_kwds[vx].shape
        N = array_kwds[vx].size
        # Flatten the gridded arguments so that every grid point is one simulated unit.
        array_kwds[vx].shape = N
        array_kwds[vy].shape = N
        array_kwds['fm'].shape = N
        # Run the model
        t, out = modelfunc(N, array_kwds)
        # Restore the grid shape; out gains a trailing time axis.
        array_kwds[vx].shape = shape
        array_kwds[vy].shape = shape
        array_kwds['fm'].shape = shape
        out.shape = shape+(len(t),)
        out[out<0] = 0
        # Zero the response before the whole number of modulation cycles nearest
        # to 0.5 s, so the peaks are taken only from the later part of each trace.
        fm_flat = array_kwds['fm'][:, :, :, newaxis]
        t_flat = t[newaxis, newaxis, newaxis, :]
        n_flat = array(around(0.5*second*fm_flat), dtype=int)
        idx = t_flat<n_flat/fm_flat
        out[idx] = 0
        # Per grid point and frequency: peak time -> phase, and peak value.
        # The code treats axis 0 as vertical, axis 1 as horizontal and axis 2 as fm.
        peak = t[argmax(out, axis=3)]
        peak_phase = (peak*2*pi*array_kwds['fm']) % (2*pi)
        peak_fr = amax(out, axis=3)
        # MTF normalised to its maximum across modulation frequencies.
        norm_peak_fr = peak_fr/amax(peak_fr, axis=2)[:, :, newaxis]
        # Squared phase error summed (not averaged) over fm.
        # NOTE(review): differences are not wrapped around 2*pi -- confirm intended.
        mse = sum((dietz_phase[newaxis, newaxis, :]-peak_phase)**2, axis=2)
        all_peak_phase = peak_phase.reshape((M*M, -1))
        all_norm_peak_fr = norm_peak_fr.reshape((M*M, -1))
        # Best modulation frequency (on the dietz_fm grid) and modulation depth.
        bmf = asarray(dietz_fm)[argmax(norm_peak_fr, axis=2)]
        moddepth = 1-amin(norm_peak_fr, axis=2)
        # Dense frequency axis, used for BMF refinement and for the MTF plot.
        fm_interp = linspace(4, 64, 1000)
        if interpolate_bmf:
            # Refine the BMF by maximising a quadratic interpolation of each MTF.
            for cx in range(M):
                for cy in range(M):
                    cur_fr = norm_peak_fr[cy, cx, :]
                    fr_interp_func = interp1d(dietz_fm, cur_fr, kind='quadratic')
                    bmf[cy, cx] = fm_interp[argmax(fr_interp_func(fm_interp))]
        # Properties of lowest MSE value
        idx_best_y, idx_best_x = unravel_index(argmin(mse), mse.shape)
        xbest = axis_ranges[vx][idx_best_x]
        ybest = axis_ranges[vy][idx_best_y]
        best_peak_phase = peak_phase[idx_best_y, idx_best_x, :]
        best_norm_peak_fr = norm_peak_fr[idx_best_y, idx_best_x, :]
        # A single-argument print(...) behaves the same in Python 2 and 3.
        print('Best: {vx} = {xbest}, {vy} = {ybest}'.format(vx=vx, vy=vy, xbest=xbest, ybest=ybest))
        # Plot the data
        extent = (kwds[vx]+kwds[vy])  # (xmin, xmax, ymin, ymax)
        def labelit(titletext):
            # Mark the best grid point and label the current map panel.
            plot([xbest], [ybest], '+w')
            title(titletext)
            xlabel(sliders[vx].description)
            ylabel(sliders[vy].description)
            cb = colorbar()
            cb.set_label(titletext, rotation=270, labelpad=20)

        curfig = figure(figsize=(14, 8))
        clf()
        gs = GridSpec(2, 6, height_ratios=[1, 1])

        # matplotlib only accepts origin='upper' or 'lower'; newer versions
        # reject the former 'lower left' spelling.
        subplot(gs[0, :2])
        imshow(mse, origin='lower', aspect='auto',
               interpolation='nearest', vmin=0, extent=extent)
        labelit('MSE')

        subplot(gs[0, 2:4])
        imshow(bmf, origin='lower', aspect='auto',
               interpolation='nearest',
               vmin=float(amin(dietz_fm)), vmax=float(amax(dietz_fm)),
               extent=extent)
        labelit('BMF')

        subplot(gs[0, 4:6])
        imshow(moddepth, origin='lower', aspect='auto',
               interpolation='nearest', vmin=0, vmax=1,
               extent=extent)
        labelit('Modulation depth')

        subplot(gs[1, :3])
        plot(dietz_fm/Hz, all_peak_phase.T*180/pi, '-', color=(0.2, 0.7, 0.2, 0.2), label='Model (all)')
        plot(dietz_fm/Hz, best_peak_phase*180/pi, '-o', lw=2, label='Model (best)')
        plot(dietz_fm/Hz, dietz_phase*180/pi, '--r', label='Data')
        # Keep one legend entry per label (the 'all' curves share a single label).
        handles, labels = gca().get_legend_handles_labels()
        lab2hand = OrderedDict()
        for h, l in zip(handles, labels):
            lab2hand[l] = h
        legend(list(lab2hand.values()), list(lab2hand.keys()), loc='upper left')
        grid()
        ylim(0, 180)
        xlabel('Modulation frequency (Hz)')
        ylabel('Extracted phase (deg)')

        subplot(gs[1, 3:])
        plot(dietz_fm/Hz, all_norm_peak_fr.T, '-', color=(0.2, 0.7, 0.2, 0.2))
        plot(dietz_fm/Hz, best_norm_peak_fr, '-o')
        fr_interp_func = interp1d(dietz_fm/Hz, best_norm_peak_fr, kind='quadratic')
        plot(fm_interp, fr_interp_func(fm_interp), ':k')

        ylim(0, 1)
        xlabel('Modulation frequency (Hz)')
        ylabel('Relative MTF')

        tight_layout()

    return plotter

# Parameter widgets of the most recently built 2d-map GUI. The factory returned
# by map2d refills this dict, and it is passed to the map's load/save widget.
current_map2d_widgets = {}

def map2d(runmodel, vs2d):
    """Return a zero-argument factory that builds the interactive 2d-map widget.

    Each call to the factory does the following:

    * merges the range sliders (for the selected axes) with the plain sliders;
    * publishes those parameter widgets in ``current_map2d_widgets``;
    * wraps ``plot2d(runmodel, vs2d)`` in ``ipw.interactive`` with the extra
      detail and BMF-interpolation controls.
    """
    def build_widget():
        controls = vs2d.merge_selected(range_sliders, sliders)
        # Only the model-parameter widgets are shared with the load/save widget;
        # the detail/interpolation controls are added afterwards.
        current_map2d_widgets.clear()
        current_map2d_widgets.update(controls)
        controls['detail'] = detail_slider
        controls['interpolate_bmf'] = ipw.Checkbox(description="Interpolate BMF", value=True)
        interactive_widget = ipw.interactive(plot2d(runmodel, vs2d), **controls)
        return no_continuous_update(interactive_widget)
    return build_widget

#########################################################################
# Model 2: population space
def plot_population_space(**kwds):
    """Sample random parameter sets within the slider ranges and plot the fits.

    ``kwds`` holds a ``(low, high)`` range for each model parameter, plus the
    ``'detail'`` and ``'interpolate_bmf'`` controls.

    * N parameter sets (N set by the detail level) are drawn uniformly at
      random within the ranges.
    * Each set is simulated at every frequency in ``dietz_fm``.
    * The set with the lowest summed squared phase error is printed and
      highlighted in the phase and MTF plots.

    The figure is stored in the module-level ``curfig`` so that the
    save-figure widget saves this plot.
    """
    # Without this, curfig would be a local and the save widget would save a
    # stale figure (or fail if none had been drawn yet).
    global curfig
    # Get simple parameters
    # NOTE(review): the 'interpolate_bmf' control has no effect in this view; it is
    # only removed so it is not treated as a model parameter below.
    kwds.pop('interpolate_bmf')
    detail_settings = dict(Low=100, Medium=1000, High=10000)
    N = detail_settings[kwds.pop('detail')]
    # Set up array keywords
    array_kwds = {}
    param_values = {}
    for k, (low, high) in kwds.items():
        v = rand(N)*(high-low)+low
        param_values[k] = v
        # Pair each sampled value with each fm: both grids have shape (N, len(dietz_fm)).
        fm, v = meshgrid(dietz_fm, v)
        fm.shape = fm.size
        v.shape = v.size
        array_kwds['fm'] = fm
        array_kwds[k] = v
    # Run the model
    t, out = simple_model(N*len(dietz_fm), array_kwds)
    out.shape = (N, len(dietz_fm), len(t))
    out[out<0] = 0
    fm = dietz_fm
    # Zero the response before the whole number of modulation cycles nearest to
    # 0.5 s; the boolean addition broadcasts the mask to out's full shape.
    n = array(around(0.5*second*fm), dtype=int)
    idx = (t[newaxis, newaxis, :]<(n/fm)[newaxis, :, newaxis])+zeros(out.shape, dtype=bool)
    out[idx] = 0
    peak = t[argmax(out, axis=2)] # shape (N, n_fm)
    peak_phase = (peak*2*pi*fm[newaxis, :]) % (2*pi) # shape (N, n_fm)
    peak_fr = amax(out, axis=2) # shape (N, n_fm)
    norm_peak_fr = peak_fr/amax(peak_fr, axis=1)[:, newaxis]
    # Squared phase error summed over fm, shape (N,).
    # NOTE(review): differences are not wrapped around 2*pi -- confirm intended.
    mse = sum((dietz_phase[newaxis, :]-peak_phase)**2, axis=1)
    # Properties of lowest MSE value
    idx_best = argmin(mse)
    best_peak_phase = peak_phase[idx_best, :]
    best_norm_peak_fr = norm_peak_fr[idx_best, :]
    bestvals = []
    for k in kwds.keys():
        v = param_values[k][idx_best]
        bestvals.append('%s=%.2f' % (k, v))
    # A single-argument print(...) behaves the same in Python 2 and 3.
    print('Best: ' + ', '.join(bestvals))
    # Plot the data
    curfig = figure(figsize=(14, 5))
    clf()
    subplot(121)
    # Fade the individual curves more as the population grows.
    transp = clip(0.3*100./N, 0.01, 1)
    plot(dietz_fm/Hz, peak_phase.T*180/pi, '-', color=(0.4, 0.7, 0.4, transp), label='Model (all)')
    plot(dietz_fm/Hz, best_peak_phase*180/pi, '-o', lw=2, label='Model (best)')
    plot(dietz_fm/Hz, dietz_phase*180/pi, '--r', label='Data')
    # Keep one legend entry per label (the 'all' curves share a single label).
    handles, labels = gca().get_legend_handles_labels()
    lab2hand = OrderedDict()
    for h, l in zip(handles, labels):
        lab2hand[l] = h
    legend(list(lab2hand.values()), list(lab2hand.keys()), loc='upper left')
    grid()
    ylim(0, 180)
    xlabel('Modulation frequency (Hz)')
    ylabel('Extracted phase (deg)')

    subplot(122)
    plot(dietz_fm/Hz, norm_peak_fr.T, '-', color=(0.4, 0.7, 0.4, transp))
    plot(dietz_fm/Hz, best_norm_peak_fr, '-o')
    fm_interp = linspace(4, 64, 1000)
    fr_interp_func = interp1d(dietz_fm/Hz, best_norm_peak_fr, kind='quadratic')
    plot(fm_interp, fr_interp_func(fm_interp), ':C0')

    ylim(0, 1)
    xlabel('Modulation frequency (Hz)')
    ylabel('Relative MTF')

    tight_layout()

    
def population_space():
    """Build the interactive widget for the 'Population' model.

    Every range slider becomes an argument of ``plot_population_space``,
    together with the detail dropdown and a BMF-interpolation checkbox.
    """
    controls = OrderedDict(range_sliders)
    controls['detail'] = detail_slider
    controls['interpolate_bmf'] = ipw.Checkbox(description="Interpolate BMF", value=True)
    return no_continuous_update(ipw.interactive(plot_population_space, **controls))

#########################################################################
# Construct and show GUI

# Each entry: (model name, widget factory, options, extra widgets shown alongside).
models = []
models.append(('MSE/MTF 2d map',
               map2d(simple_model, vs2d_mse_mtf),
               options2d_mse_mtf,
               [load_save_parameters_widget(current_map2d_widgets, 'saved_params_simple_map2d'),
                widget_savefig, progress_slider]))
models.append(('Population',
               population_space,
               {},
               [load_save_parameters_widget(range_sliders, 'saved_params_simple_population'),
                widget_savefig, progress_slider]))

# Create model explorer, and jump immediately to results page
modex = model_explorer(models)
modex.widget_model_type.value = 'Population'
modex.tabs.selected_index = 1
display(modex)