In [19]:
import mpl_interactions.ipyplot as iplt
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors
from scipy.stats import norm, multivariate_normal
from matplotlib import gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable
import itertools as it
import ipywidgets as widgets
from scipy import signal, fft
import warnings
warnings.filterwarnings("ignore")
%matplotlib ipympl

In [3]:
def gaussian_dist_2d(loc, scale):
    """Return a frozen isotropic 2-D ``multivariate_normal``.

    ``loc`` is broadcast onto a length-2 mean vector (a scalar places the
    centre on the diagonal, e.g. ``loc=1`` -> mean ``[1, 1]``).
    ``scale`` multiplies the 2x2 identity, so it is the per-axis *variance*
    (not a standard deviation).
    """
    return multivariate_normal(np.zeros(2) + loc, np.eye(2) * scale)

def plot_Z(Z, extent=(-5, 5, -5, 5), norm=None):
    """Show a 2-D map on a new 5x5-inch figure with a colour bar on its right.

    Parameters
    ----------
    Z : 2-D array
        Map to display. Drawn with ``origin="lower"``, so row 0 is at the bottom.
    extent : sequence of 4 floats
        ``(left, right, bottom, top)`` data coordinates forwarded to ``imshow``.
    norm : matplotlib.colors.Normalize, optional
        Colour normalisation. When omitted, the module-level ``divnorm`` is
        looked up at call time. Later cells in this notebook rebind
        ``divnorm`` to a different range, so pass ``norm`` explicitly if the
        colours must not depend on which cells have already run.
    """
    if norm is None:
        norm = divnorm
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111)
    # Thin axes (5% of the width) appended on the right to hold the colour bar.
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    im = ax.imshow(Z, aspect="equal", cmap="RdBu_r", norm=norm, origin="lower", extent=extent)
    fig.colorbar(im, cax=cax)

def plot_Z_multiple(Z, ax, extent=(-5, 5, -5, 5), norm=None):
    """Draw a 2-D map into an existing Axes and attach a colour bar beside it.

    Parameters
    ----------
    Z : 2-D array
        Map to display (``origin="lower"``: row 0 at the bottom).
    ax : matplotlib Axes
        Target axes. The colour bar is carved out of its right edge.
    extent : sequence of 4 floats
        ``(left, right, bottom, top)`` data coordinates forwarded to ``imshow``.
    norm : matplotlib.colors.Normalize, optional
        Colour normalisation. Defaults to the module-level ``divnorm`` at call
        time. That global is rebound by later cells, so pass ``norm``
        explicitly for history-independent colours.
    """
    if norm is None:
        norm = divnorm
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    im = ax.imshow(Z, aspect="equal", cmap="RdBu_r", norm=norm, origin="lower", extent=extent)
    # Use ax's own figure rather than pyplot's "current" figure, which may differ.
    ax.figure.colorbar(im, cax=cax)

In [4]:
def gaussian_2d(position_x, position_y, scale, amp, grid=None):
    """Isotropic 2-D Gaussian bump evaluated on a grid and rescaled to peak ``amp``.

    Parameters
    ----------
    position_x, position_y : float
        Centre of the bump, in the same units as the grid coordinates.
    scale : float
        Per-axis variance (the covariance is ``scale * I``).
    amp : float
        Value of the largest grid sample after rescaling.
    grid : array of shape (..., 2), optional
        ``(x, y)`` points to evaluate at. Defaults to the module-level ``d``
        grid from the setup cell.

    Returns
    -------
    ndarray of shape ``grid.shape[:-1]``
        The rescaled density. If the density underflows to 0 at every point
        (bump far off the grid and/or a tiny ``scale``), the zeros are returned
        instead of NaNs (the original computed ``0 * inf``).
    """
    if grid is None:
        grid = d
    grid = np.asarray(grid)
    mean = np.array([position_x, position_y], dtype=float)
    cov = np.eye(2) * scale
    # pdf() drops the trailing coordinate axis (and squeezes singleton axes);
    # reshape restores the grid's leading shape.
    density = multivariate_normal(mean, cov).pdf(grid).reshape(grid.shape[:-1])
    peak = density.max()
    if peak == 0:
        return density
    return density * (amp / peak)


def gauss_sum(prox_pos_x, prox_pos_y, prox_scale, prox_amp, dist_pos_x, dist_pos_y, dist_scale, dist_amp):
    """Net spatial drive: proximal Gaussian minus distal Gaussian.

    Each map is built by ``gaussian_2d`` with its own centre, variance and
    peak amplitude. The proximal map contributes positively and the distal map
    negatively. The parameter names double as the slider names passed to
    ``iplt.imshow``, so they must not be renamed.
    """
    prox = gaussian_2d(prox_pos_x, prox_pos_y, prox_scale, prox_amp)
    dist = gaussian_2d(dist_pos_x, dist_pos_y, dist_scale, dist_amp)
    return prox - dist


def time_course(t, offset, speed):
    """Gaussian time course centred on ``offset`` and rescaled to a peak of 1.

    Parameters
    ----------
    t : array
        Sample times.
    offset : float
        Centre of the pulse (same units as ``t``).
    speed : float
        Standard deviation passed to ``norm.pdf`` (larger -> broader pulse).

    Returns
    -------
    ndarray like ``t``
        The pdf divided by its maximum, so the largest sample is exactly 1.
        If the pdf underflows to 0 at every sample (a very narrow pulse falling
        between samples), the zeros are returned instead of NaNs.
    """
    tc = norm.pdf(t, offset, speed)
    peak = tc.max()
    if peak == 0:
        return tc
    return tc / peak

In [5]:
def wv_transform(data, sfreq, min_freq, max_freq, num_frex, range_cycles=(3, 10)):
    """Time-frequency power of a 1-D signal via complex Morlet wavelet convolution.

    Wavelet frequencies are ``np.linspace(min_freq, max_freq, num_frex)``. The
    number of cycles grows logarithmically from ``range_cycles[0]`` to
    ``range_cycles[-1]`` across them. Convolution is done in the frequency
    domain, zero-padded to the full linear-convolution length, and then trimmed
    back to ``len(data)`` samples.

    Parameters
    ----------
    data : 1-D array
        Signal to analyse.
    sfreq : float
        Sampling rate of ``data`` in Hz. Wavelets span -1 s .. 1 s at this rate.
    min_freq, max_freq : float
        Lowest and highest wavelet frequency in Hz.
    num_frex : int
        Number of wavelet frequencies.
    range_cycles : pair of float, optional
        Cycle count at the lowest and highest frequency (default ``(3, 10)``,
        which matches the previous hard-coded behaviour).

    Returns
    -------
    ndarray of shape (num_frex, len(data))
        Power ``|conv| ** 2``. Row ``i`` is the i-th linspace frequency.
    """
    frex = np.linspace(min_freq, max_freq, num=num_frex)

    # Gaussian envelope width (s) per frequency: n_cycles / (2*pi*f).
    # NOTE: f == 0 gives an infinite width, i.e. a flat (DC) wavelet.
    n_cycles = np.logspace(np.log10(range_cycles[0]), np.log10(range_cycles[-1]), num=num_frex)
    cycvec = n_cycles / (2 * np.pi * frex)

    wavtime = np.arange(-1, 1 + 1 / sfreq, 1 / sfreq)
    n_wave = wavtime.shape[0]
    half_wave = (n_wave - 1) / 2
    n_data = data.shape[0]
    n_conv = n_wave + n_data - 1
    # Slice that realigns the full linear convolution with the input samples.
    start, stop = int(half_wave), int(n_conv - half_wave)

    # The signal spectrum does not depend on the wavelet frequency: compute it once.
    data_x = fft.fft(data, n_conv)

    tf_data = np.zeros((frex.shape[0], n_data))
    for f_i, f in enumerate(frex):
        sine = np.exp(2j * np.pi * f * wavtime)
        envelope = np.exp(-(wavtime ** 2) / (2 * cycvec[f_i] ** 2))
        wavelet_x = fft.fft(sine * envelope, n_conv)
        # Normalise so the spectral peak has unit magnitude. np.max on complex
        # arrays orders lexicographically (by real part first), not by
        # magnitude, so select the largest-|.| bin explicitly.
        wavelet_x = wavelet_x / wavelet_x[np.argmax(np.abs(wavelet_x))]
        data_conv = fft.ifft(wavelet_x * data_x)[start:stop]
        tf_data[f_i, :] = np.abs(data_conv) ** 2
    return tf_data

In [46]:
# Shared colour norm plus the spatial and temporal sampling grids used below.
divnorm = colors.TwoSlopeNorm(vmin=-10., vcenter=0, vmax=10)
resolution = 100
# 1-D axis over [-5, 5]; the 2-D grid is this axis crossed with itself.
space = np.linspace(-5, 5, resolution)
X, Y = np.meshgrid(space, space)
# (resolution, resolution, 2) stack of (x, y) points for pdf evaluation.
d = np.dstack((X, Y))
# 200 time samples spanning [-200, 200].
t = np.linspace(-200, 200, num=200)

In [64]:

# Interactive explorer. Left panel: net spatial drive (gauss_sum).
# Right panel: proximal / distal time courses.
# Each (low, high) tuple below is the range of one interactive slider.
spatial_sliders = {
    f"{drive}_{param}": bounds
    for drive in ("prox", "dist")
    for param, bounds in (
        ("pos_x", (-5, 5)),
        ("pos_y", (-5, 5)),
        ("scale", (0.1, 25)),
        ("amp", (1, 10)),
    )
}
timecourse_sliders = {"offset": (-200, 200), "speed": (0.1, 50)}

figure = plt.figure(figsize=[9, 4])
gs = gridspec.GridSpec(1, 2, wspace=0.2, hspace=0.2)

ax = figure.add_subplot(gs[0, 0], label="0")
controls = iplt.imshow(
    gauss_sum,
    **spatial_sliders,
    aspect="equal", cmap="RdBu_r", norm=divnorm, origin="lower",
    extent=[-5, 5, -5, 5], ax=ax, interpolation=None,
)

ax = figure.add_subplot(gs[0, 1], label="2")
# Two independent slider sets on the same axes: proximal first, then distal.
controls_prox, controls_dist = (
    iplt.plot(t, time_course, **timecourse_sliders, ax=ax, label=drive_label, c=colour)
    for drive_label, colour in (("Proximal drive", "red"), ("Distal drive", "blue"))
)
_ = plt.legend(loc=8, fontsize="x-small")  # loc=8 is "lower center"

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

VBox(children=(HBox(children=(IntSlider(value=0, description='prox_pos_x', max=49, readout=False), Label(value…

VBox(children=(HBox(children=(IntSlider(value=0, description='offset', max=49, readout=False), Label(value='-2…

VBox(children=(HBox(children=(IntSlider(value=0, description='offset', max=49, readout=False), Label(value='-2…

In [91]:
# Snapshot the current widget values and build the space-time drive.
# NOTE(review): results depend on the slider state when this cell runs.
prox = gaussian_2d(controls.params['prox_pos_x'], controls.params['prox_pos_y'], controls.params['prox_scale'], controls.params['prox_amp'])
dist = gaussian_2d(controls.params['dist_pos_x'], controls.params['dist_pos_y'], controls.params['dist_scale'], controls.params['dist_amp'])
prox_tc = time_course(t, controls_prox.params['offset'], controls_prox.params['speed'])
dist_tc = time_course(t, controls_dist.params['offset'], controls_dist.params['speed'])
# Scale each spatial map by its time course -> shape (len(t), *prox.shape).
prox_all = prox_tc[:, None, None] * prox
dist_all = dist_tc[:, None, None] * dist
sum_all = prox_all - dist_all

# Time-frequency settings. frex mirrors the axis wv_transform builds
# internally, so a row index can be converted to Hz.
SFREQ, MIN_FREQ, MAX_FREQ, NUM_FREX = 500, 0, 60, 100
frex = np.linspace(MIN_FREQ, MAX_FREQ, num=NUM_FREX)

res = np.arange(10, 100, 10)
gs_size = res.shape[0]

gs = gridspec.GridSpec(gs_size, gs_size, wspace=0.02, hspace=0.2)
fig = plt.figure(figsize=[15, 15])
fontdict = {
    'fontsize': 8,
    'fontweight': 1
}
gs_norm = colors.TwoSlopeNorm(vmin=-1., vcenter=0, vmax=1)
# Panels fill the grid row-major: rows go from high to low row index (flip),
# columns from low to high, matching the origin="lower" maps.
coord = np.array(list(it.product(np.flip(res), res)))
# Loop-invariant vertical offset that places the x10 raw trace in each panel.
trace_baseline = NUM_FREX - 10 * sum_all.max() - 10
for ix, (x, y) in enumerate(coord):
    ax = fig.add_subplot(gs[ix], label=str((x, y)))

    sig = sum_all[:, x, y]
    tf = wv_transform(sig, SFREQ, MIN_FREQ, MAX_FREQ, NUM_FREX)
    # Frequency row of the global power maximum (first one on ties).
    peak_row = np.unravel_index(np.argmax(tf), tf.shape)[0]
    ax.imshow(
        tf,
        cmap="RdBu_r",
        aspect=tf.shape[1] / tf.shape[0],
        origin="lower",
        norm=gs_norm,
        interpolation="gaussian"
    )
    # Grid row index x is the y-coordinate and column index y is the
    # x-coordinate (np.meshgrid default 'xy' layout).
    ax.set_title(
        "x:{} y:{}, mf:{:.1f}Hz".format(np.round(space[y], 0), np.round(space[x], 0), frex[peak_row]),
        fontdict=fontdict, pad=-1
    )
    ax.plot(sig * 10 + trace_baseline, lw=.5, c="black", alpha=0.4)

    ax.axis('off')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [89]:
# Peak-power frequency (Hz) of the space-time drive on every SUMMARY_STEP-th grid point.
SUMMARY_STEP = 2
SFREQ, MIN_FREQ, MAX_FREQ, NUM_FREX = 500, 0, 60, 100
# Same frequency axis that wv_transform builds internally (row index -> Hz).
frex = np.linspace(MIN_FREQ, MAX_FREQ, num=NUM_FREX)
rows = np.arange(0, sum_all.shape[1], SUMMARY_STEP)
cols = np.arange(0, sum_all.shape[2], SUMMARY_STEP)
summary = np.zeros((rows.size, cols.size))
for i, x in enumerate(rows):
    for j, y in enumerate(cols):
        tf = wv_transform(sum_all[:, x, y], SFREQ, MIN_FREQ, MAX_FREQ, NUM_FREX)
        # Frequency row of the global power maximum, converted to Hz so it
        # matches the 0..MAX_FREQ colour scale below.
        peak_row = np.unravel_index(np.argmax(tf), tf.shape)[0]
        summary[i, j] = frex[peak_row]

fig, ax = plt.subplots(figsize=(5, 5))
# Private name: do not rebind the module-level ``divnorm`` used by plot_Z and
# the widget cell. vmin sits just below 0, presumably because TwoSlopeNorm
# needs vmin < vcenter.
summary_norm = colors.TwoSlopeNorm(vmin=-0.001, vcenter=0, vmax=MAX_FREQ)
cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
im = ax.imshow(summary, aspect="equal", norm=summary_norm, cmap="RdBu_r", origin="lower", extent=[-5, 5, -5, 5])
fig.colorbar(im, cax=cax)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.colorbar.Colorbar at 0x7f4e6e8ac160>

In [90]:
# Re-plot the peak-frequency summary map. A private norm is used so the
# module-level ``divnorm`` (used by plot_Z and the widget cell) is not rebound.
# vmin sits just below 0, presumably because TwoSlopeNorm needs vmin < vcenter.
fig, ax = plt.subplots(figsize=(5, 5))
summary_norm = colors.TwoSlopeNorm(vmin=-0.001, vcenter=0, vmax=60)
cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
im = ax.imshow(summary, aspect="equal", norm=summary_norm, cmap="RdBu_r", origin="lower", extent=[-5, 5, -5, 5])
fig.colorbar(im, cax=cax)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.colorbar.Colorbar at 0x7f4e6e752250>