In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import nengo
from nengo.utils.ensemble import sorted_neurons
from nengo_extras.plot_spikes import preprocess_spikes

import sys
sys.path.append('../model')

import benchmark

def 𐌈(**kwargs):
    """Return the supplied keyword arguments as a plain ``dict``.

    Shorthand for writing ``dict(key=value, ...)``-style keyword bundles.
    """
    bundle = dict(kwargs)
    return bundle

In [None]:
# Simulated duration of the decoding experiments (the figures below use
# this as the time axis, whose scale bars are labelled in seconds).
T = 4.0

# Network configuration shared by every simulation in this notebook; it is
# forwarded as keyword arguments to benchmark.build_and_run_test_network.
detailed_kwargs = {
    "mode": "two_populations_dales_principle",
    "use_spatial_constraints": True,
    # Convergence counts between the populations -- presumably the number
    # of inputs per target cell; see benchmark for the exact meaning.
    "n_pcn_golgi_convergence": 100,
    "n_pcn_granule_convergence": 5,
    "n_granule_golgi_convergence": 100,
    "n_golgi_granule_convergence": 5,
    "n_golgi_golgi_convergence": 100,
    # Population sizes.
    "n_granule": 10000,
    "n_golgi": 100,
}

In [None]:
# Pulse-input run (plotted as panel A). Re-seed NumPy's global RNG first so
# the run is reproducible -- presumably benchmark draws from the global RNG.
np.random.seed(4192)
res1 = benchmark.build_and_run_test_network(
    benchmark.pulse_input(0.5, 0.5),
    T=T,
    probe_granule_decoded=True,
    **detailed_kwargs
)

In [None]:
# White-noise-input run (plotted as panel B, whose label gives a 5 Hz
# bandwidth). Same seed as the other runs for reproducibility.
np.random.seed(4192)
res2 = benchmark.build_and_run_test_network(
    benchmark.white_noise_input(5.0),
    T=T,
    probe_granule_decoded=True,
    **detailed_kwargs
)

In [None]:
# Short (T = 1.0) pulse-input run that also probes the spatial connectivity
# data used by the spatial-constraints figure.
np.random.seed(4192)
res3 = benchmark.build_and_run_test_network(
    benchmark.pulse_input(0.5, 0.5),
    T=1.0,
    probe_spatial_data=True,
    **detailed_kwargs
)

In [None]:
def hide_spines(ax):
    """Make the top, right, bottom and left spines of ``ax`` invisible.

    Returns ``ax`` itself so the call can wrap axes creation inline.
    """
    sides = ('top', 'right', 'bottom', 'left')
    for spine in (ax.spines[side_name] for side_name in sides):
        spine.set_visible(False)
    return ax

def make_xs_tars(ts, xs, theta, delays):
    """Return copies of ``xs`` delayed by ``theta * delay`` for each delay.

    Parameters
    ----------
    ts : array
        Sample times; assumed uniformly spaced (dt is derived from the
        first/last sample).
    xs : 1D array
        Signal to delay, one value per entry of ``ts``.
    theta : float
        Delay scale, in the same units as ``ts``.
    delays : iterable of float
        Relative delays; the absolute shift is ``theta * delay``.

    Returns
    -------
    array, shape (len(delays), len(xs))
        Row ``i`` is ``xs`` shifted right by ``round(theta * delays[i] / dt)``
        samples and zero-padded at the start. Shifts longer than the
        recording yield an all-zero row of the same length.

    Raises
    ------
    ValueError
        If a delay corresponds to a negative shift.
    """
    xs = np.asarray(xs)
    n_samples = len(xs)
    dt = (ts[-1] - ts[0]) / (len(ts) - 1)

    xs_tars = []
    for delay in delays:
        # Round rather than truncate: ratios such as 0.3 / 0.1 evaluate to
        # 2.9999999999999996 and would otherwise lose a whole sample.
        n = int(np.round(theta * delay / dt))
        if n < 0:
            raise ValueError("delays must be non-negative, got {}".format(delay))
        # Clamp so every row has exactly n_samples entries; otherwise the
        # rows become ragged and np.array cannot stack them.
        n = min(n, n_samples)
        xs_tars.append(np.concatenate((np.zeros(n), xs[:n_samples - n])))
    return np.array(xs_tars)

def decode_delay(xs_tar, As, sigma=0.1, seed=58791, dt=1e-3,
                 n_max_neurons=1000, t_skip=0.0):
    """Fit an L2-regularised linear decoder from ``As`` to ``xs_tar``.

    Parameters
    ----------
    xs_tar : 1D array
        Target signal, one sample per row of ``As``.
    As : 2D array, shape (n_steps, n_neurons)
        Activities to decode from.
    sigma : float
        Regularisation strength relative to the maximum activity.
    seed : int
        Seed for choosing the neuron subset when ``As`` has more than
        ``n_max_neurons`` columns.
    dt : float
        Time step, used to convert ``t_skip`` into a number of samples.
    n_max_neurons : int
        At most this many (randomly selected) neurons are used.
    t_skip : float
        Initial duration excluded when fitting the decoder (the default 0
        fits on all samples).

    Returns
    -------
    1D array
        The decoded signal ``As @ D`` for every time step, where ``As`` is
        the (possibly subsampled) activity matrix.

    Raises
    ------
    ValueError
        If ``t_skip`` leaves no samples to fit on.
    """
    # Check some dimensions
    assert xs_tar.ndim == 1
    assert As.ndim == 2
    assert xs_tar.shape[0] == As.shape[0]

    # Decode from a fixed random subset of neurons to bound the solve cost.
    if As.shape[1] > n_max_neurons:
        all_idcs = np.arange(As.shape[1], dtype=int)
        idcs = np.random.RandomState(seed).choice(all_idcs,
                                                  n_max_neurons,
                                                  replace=False)
        As = As[:, idcs]

    # Samples before n0 are ignored when fitting (but still decoded).
    n0 = int(round(t_skip / dt))
    if n0 >= As.shape[0]:
        raise ValueError("t_skip removes all {} samples".format(As.shape[0]))
    Asp, xs_tarp = As[n0:], xs_tar[n0:]

    # Ridge regression: (A^T A + m (sigma * max A)^2 I) D = A^T x, the same
    # regularisation form as nengo's LstsqL2 solver.
    reg = Asp.shape[0] * np.square(sigma * np.max(Asp))
    gram = Asp.T @ Asp + reg * np.eye(Asp.shape[1])
    D = np.linalg.lstsq(gram, Asp.T @ xs_tarp, rcond=None)[0]

    return As @ D


def unfilter_spike_train(ts, As):
    """Approximate the spike train underlying filtered activities ``As``.

    A spike is recorded at every time step (after the first) where an
    activity increases by more than 5 relative to the previous step; each
    spike has amplitude ``1 / dt``, with dt derived from ``ts``. All other
    entries are zero. The result has the same shape and dtype as ``As``.
    """
    step = (ts[-1] - ts[0]) / (len(ts) - 1)
    jumps = np.diff(As, axis=0) > 5
    spikes = np.zeros_like(As)
    spikes[1:][jumps] = 1 / step
    return spikes


def rasterplot(ax, ts, A, **style):
    """Draw a spike raster of ``A`` onto ``ax``.

    Every nonzero entry ``A[k, i]`` becomes a vertical tick at time
    ``ts[k]`` spanning y in [i + 0.5, i + 1.5]. All ticks are drawn as a
    single NaN-separated line. This is far cheaper than one ``ax.plot``
    call per spike and keeps one colour when ``style`` does not set one.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    ts : array, shape (n_steps,)
        Sample times.
    A : array, shape (n_steps, n_neurons)
        Spike data; any nonzero value counts as a spike.
    **style
        Forwarded to ``ax.plot`` (e.g. ``color``, ``linewidth``).
    """
    ts = np.asarray(ts)
    A = np.asarray(A)
    n = A.shape[1]

    # Spike coordinates, grouped by neuron.
    neuron_idcs, step_idcs = np.nonzero(A.T)
    if len(step_idcs) > 0:
        t = ts[step_idcs].astype(float)
        gap = np.full_like(t, np.nan)  # NaN breaks the line between ticks
        y0 = neuron_idcs + 0.5
        xs = np.column_stack((t, t, gap)).ravel()
        ys = np.column_stack((y0, y0 + 1.0, gap)).ravel()
        ax.plot(xs, ys, zorder=-100, solid_capstyle="butt", **style)
    ax.set_ylim(0.5, n + 0.5)

def plot_decoding_example(ts, xs, ys, As, theta, name, ax1_ylim, ax2_ylim, ax3_ylim,
                          show_t_on=False, t_max=None, seed=None):
    """Plot an input signal, a granule-cell raster and decoded delays.

    The figure has three vertically stacked panels sharing the time axis:
    the input ``xs`` (top), a raster of up to 40 randomly chosen neurons
    reconstructed from ``As`` (middle), and linear decodes of the input
    delayed by ``0``, ``theta / 2`` and ``theta`` together with the true
    delayed signals drawn dotted (bottom).

    Parameters
    ----------
    ts : array
        Sample times (assumed uniformly spaced).
    xs : 1D array
        Input signal u(t), one value per entry of ``ts``.
    ys : array
        Unused; kept so ``plot_decoding_example(*res, ...)`` keeps working.
    As : 2D array, shape (len(ts), n_neurons)
        Granule cell activities.
    theta : float
        Delay window, in the units of ``ts``; also sets the scale bar.
    name : str
        Bold panel label.
    ax1_ylim, ax3_ylim : (float, float)
        y limits of the input and decoded-delay panels.
    ax2_ylim : (float, float)
        Ignored: the raster limits are derived from the number of plotted
        neurons. Kept for backwards compatibility.
    show_t_on : bool
        Annotate the pulse duration t_on instead of the bandwidth B.
    t_max : float, optional
        Right x limit; defaults to ``ts[-1]``.
    seed : int, optional
        Seed for picking the raster neurons; ``None`` draws from NumPy's
        global RNG.

    Returns
    -------
    matplotlib Figure
    """
    if t_max is None:
        t_max = ts[-1]
    # plt.get_cmap works on all Matplotlib versions; cm.get_cmap was removed in 3.9.
    cmap = plt.get_cmap('viridis')

    fig = plt.figure(figsize=(4.2, 2.4))

    # --- Top panel: input u(t) with scale bars ------------------------------
    ax1 = hide_spines(plt.subplot2grid((9, 1), (0, 0), fig=fig))
    ax1.plot(ts, xs, color='k', clip_on=False, linewidth=1.0)
    ax1.set_ylim(*ax1_ylim)
    ax1.set_yticks([])
    ax1.set_xlim(0, t_max)
    ax1.set_xticks([])
    ax1.text(1.0, 1.8, 'Input $u(t)$',
             ha='right', va='top', fontsize=8, transform=ax1.transAxes)

    # Horizontal scale bar whose length is exactly theta, starting at t = 0.2.
    bar_start = 0.2
    ax1.plot([bar_start, bar_start + theta], [-0.5, -0.5], color='k', linewidth=1.5,
             clip_on=False, solid_capstyle="butt")
    ax1.text(bar_start + 0.5 * theta, ax1_ylim[0] - 0.7,
             '$\\theta = {:g}\\,\\mathrm{{s}}$'.format(theta),
             ha='center', va='top', fontsize=8)

    if show_t_on:
        ax1.plot([0.6, 0.8], [1.25, 1.25], color='k', linewidth=1.5, clip_on=False,
                 solid_capstyle="butt")
        ax1.text(0.7, 1.27, '$t_\\mathrm{on} = 0.2\\,\\mathrm{s}$',
                 ha='center', va='bottom', fontsize=8)
    else:
        ax1.text(1.75, 1.0, 'Bandwidth $B = 5.0\\,\\mathrm{Hz}$',
                 ha='center', va='bottom', fontsize=8)

    # Vertical scale bar of the given height at t = 0.
    height = 1
    ax1.plot([0.0, 0.0], [0.0, height], color='k', linewidth=1.5, clip_on=False,
             solid_capstyle="butt")
    ax1.text(0.025, 0.4, '${}$'.format(height),
             ha='left', va='center', fontsize=8)
    ax1.axhline(0, linestyle='--', linewidth=0.5, color='black', zorder=100)

    ax1.text(0.0, 2.0, '$\\mathbf{{{}}}$'.format(name),
             ha='left', va='top', fontsize=12, transform=ax1.transAxes)

    # --- Middle panel: spike raster of a random neuron subset ----------------
    ax2 = hide_spines(plt.subplot2grid((9, 1), (1, 0), sharex=ax1, rowspan=4, fig=fig))
    ax2.set_yticks([])
    ax2.set_xlim(0, t_max)
    ax2.text(1.0, 0.8, 'Granule cell activities $\\mathbf{a}(t)$',
             ha='right', va='top', fontsize=8, transform=ax2.transAxes)

    # Compute the spike-train underlying As
    spikes = unfilter_spike_train(ts, As)

    # Randomly select (at most) n_neurons neurons
    n_neurons = min(40, As.shape[1])
    rng = np.random if seed is None else np.random.RandomState(seed)
    idcs = rng.choice(np.arange(As.shape[1]), n_neurons, replace=False)
    _, spikes = preprocess_spikes(ts, spikes[:, idcs])
    rasterplot(ax2, ts, spikes, color='k', linewidth=0.5)
    ax2.set_ylim(int(-n_neurons * 0.1), int(n_neurons * 1.5))

    # --- Bottom panel: decoded vs. true delayed inputs -----------------------
    delays = np.linspace(0.0, 1.0, 3)
    xs_tars = make_xs_tars(ts, xs, theta, delays)
    colors = [cmap(0.1), cmap(0.5), cmap(0.9)]

    ax3 = hide_spines(plt.subplot2grid((9, 1), (5, 0), sharex=ax1, rowspan=4, fig=fig))
    ax3.axhline(0, linestyle='--', linewidth=0.5, color='black', zorder=100)
    ax3.set_ylim(*ax3_ylim)
    ax3.set_yticks([])
    for i, (delay, color, xs_tar) in enumerate(zip(delays, colors, xs_tars)):
        ax3.plot(ts, decode_delay(xs_tar, As), color=color, zorder=i,
                 label='${:0.2g}$'.format(delay))
        # Target as alternating white/coloured dots (dash offsets 0 and 1)
        # so it remains visible on top of the decoded trace.
        ax3.plot(ts, xs_tar, color='white', linewidth=0.75, linestyle=(0, (1, 1)),
                 zorder=2 * (i + 1))
        ax3.plot(ts, xs_tar, color=color, linewidth=0.75, linestyle=(1, (1, 1)),
                 zorder=2 * (i + 1))
    ax3.text(1.0, 1.05, 'Decoded delays ${\\hat u}(t - \\theta\')$',
             ha='right', va='top', fontsize=8, transform=ax3.transAxes)

    ax3.plot([0.0, 0.0], [0.0, 1.0], color='k', linewidth=1.5, clip_on=False,
             solid_capstyle="butt", zorder=10)
    ax3.text(0.025, 0.45, '$1$',
             ha='left', va='center', fontsize=8)
    ax3.legend(
        loc='lower right',
        bbox_to_anchor=(1.025, -0.2),
        ncol=len(delays),
        fontsize=7,
        columnspacing=1.0,
        handlelength=0.9,
        handletextpad=0.5,
    )
    ax3.text(0.7, -0.01, 'Delay $\\theta\'/\\theta$',
             ha='right', va='center', fontsize=7, transform=ax3.transAxes)

    fig.subplots_adjust(wspace=0.05, hspace=0.1)

    return fig

In [None]:
# Panel A: decoding from the pulse-input run.
fig = plot_decoding_example(
    *res1,
    name="A",
    ax1_ylim=(0, 1),
    ax2_ylim=(-1.0, 1.25),
    ax3_ylim=(-0.3, 1.3),
    show_t_on=True,
)
fig.savefig(
    'delay_example_a.pdf',
    bbox_inches='tight',
    transparent=True,
)

In [None]:
# Panel B: decoding from the white-noise-input run.
fig = plot_decoding_example(
    *res2,
    name="B",
    ax1_ylim=(0, 1),
    ax2_ylim=(-1.0, 1.0),
    ax3_ylim=(-1.4, 1.75),
)
fig.savefig(
    'delay_example_b.pdf',
    bbox_inches='tight',
    transparent=True,
)

In [None]:
def min_radius(p_th=0.9, sigma=0.25, n=1000):
    """Return the largest grid radius whose remaining tail mass exceeds ``p_th``.

    The 1D profile exp(-r^2 / sigma^2) is sampled on ``n`` points in [0, 2]
    and normalised to sum to one. The result is the largest sampled radius
    r_i for which one minus the cumulative mass up to and including i is
    greater than ``p_th``. Raises ValueError when no radius qualifies.
    """
    radii = np.linspace(0, 2, n)
    weights = np.exp(-radii**2 / sigma**2)
    cdf = np.cumsum(weights / np.sum(weights))
    qualifying = np.flatnonzero((1 - cdf) > p_th)
    return radii[np.max(qualifying)]
    
def plot_spatial_data(spatial_data):
    """Plot the Golgi-granule connection probabilities and cell locations.

    Panel A shows ``spatial_data["ps"]`` as an image (rows are Golgi cells,
    columns granule cells) with a colour bar for p_ij. Panel B shows the
    granule (black dots) and Golgi (orange crosses) cell locations. It also
    draws filled circles of radius ``min_radius(p_th)`` around every Golgi
    cell for five thresholds p_th between 0.25 and 0.9.

    Parameters
    ----------
    spatial_data : dict
        Must contain ``"golgi_locations"`` and ``"granule_locations"``
        (2D point arrays, one row per cell) and ``"ps"`` (2D array,
        presumably shape (n_golgi, n_granule) -- TODO confirm against
        benchmark).

    Returns
    -------
    matplotlib Figure

    Notes
    -----
    Sets the global matplotlib font size to 8 via ``matplotlib.rc``. This
    also affects figures created afterwards.
    """
    matplotlib.rc('font', size=8)

    fig, axs = plt.subplots(2, 1, figsize=(3.5, 3.5))
    ax1, ax2 = axs

    x_golgi = spatial_data["golgi_locations"]
    x_granule = spatial_data["granule_locations"]
    ps = spatial_data["ps"]
    n_golgi, n_granule = np.shape(ps)

    # --- Panel A: connection probability matrix -----------------------------
    im = ax1.imshow(ps, extent=[1, n_granule, 1, n_golgi], interpolation='none',
                    vmin=0.0, vmax=1.0)
    ax1.set_aspect('auto')
    ax1.set_xticks([1] + list(range(2000, n_granule, 2000)))
    ax1.set_xlabel('Granule cell index $i$', labelpad=0.25)
    ax1.set_ylabel('Golgi cell index $j$', labelpad=-1.0)

    # White backdrop behind the colour bar, in axes-fraction coordinates so it
    # does not depend on the matrix size. For a 100 x 10000 matrix this is the
    # data-space box (7050, 65) -- (9750, 94).
    backdrop = plt.Rectangle((0.705, 0.6465), 0.27, 0.293,
                             transform=ax1.transAxes, facecolor='white')
    ax1.add_artist(backdrop)

    cax = fig.add_axes([0.75, 0.9, 0.175, 0.02])
    cb = plt.colorbar(im, cax=cax, orientation='horizontal')
    cb.outline.set_visible(False)
    cax.text(0.5, -2.5, '$p_{ij}$', fontsize=8, ha='center', va='bottom',
             transform=cax.transAxes)

    # --- Panel B: cell locations with probability-threshold radii ----------
    # plt.get_cmap works on all Matplotlib versions; cm.get_cmap was removed in 3.9.
    cmap = plt.get_cmap('viridis')
    for j, p_th in enumerate(np.linspace(0.25, 0.9, 5)):
        r = min_radius(p_th)
        color = cmap(p_th)
        for i, center in enumerate(x_golgi):
            # Later thresholds (smaller radii) are stacked on top.
            circle = plt.Circle(center, r, fill=True, linewidth=1, color=color,
                                zorder=100 * j + i)
            ax2.add_artist(circle)

    ax2.scatter(x_granule[:, 0], x_granule[:, 1], marker='o', color='black', s=2,
                label='Granule', zorder=1000)
    ax2.scatter(x_golgi[:, 0], x_golgi[:, 1], marker='+', color='#f57900', s=30,
                label='Golgi', zorder=2000)

    ax2.set_xlim(-0.5, 0.5)
    ax2.set_ylim(-0.2, 0.2)
    ax2.set_xlabel('Spatial location $x_1$', labelpad=-0.25)
    ax2.set_ylabel('Spatial location $x_2$', labelpad=-1.0)

    fig.tight_layout(h_pad=0.0)

    # The size goes into `prop`: Matplotlib ignores `fontsize` when `prop` is given.
    legend = ax2.legend(loc='upper right', prop={"style": "italic", "size": 8},
                        facecolor="white", edgecolor="none", frameon=True,
                        fancybox=False, framealpha=1.0, borderpad=0.3,
                        handletextpad=0.2)
    legend.set_zorder(10000)

    ax1.text(-0.12, 0.9, '$\\mathbf{A}$',
             ha='right', va='bottom', fontsize=12, transform=ax1.transAxes)
    ax2.text(-0.12, 0.9, '$\\mathbf{B}$',
             ha='right', va='bottom', fontsize=12, transform=ax2.transAxes)

    return fig

In [None]:
# Spatial-constraint figure; the last element of res3 holds the spatial data.
fig = plot_spatial_data(spatial_data=res3[-1])
fig.savefig(
    'spatial_constraints.pdf',
    bbox_inches='tight',
    transparent=True,
)