In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook

from IPython.display import set_matplotlib_formats
set_matplotlib_formats('png', 'pdf')

import numpy as np
import matplotlib.pyplot as plt
import quantities as pq
from operator import itemgetter
from edog.tools import*
from edog.plot import*

filename = "params_Nr.yaml"
params = parse_parameters(filename)

# Unpack each parameter section of the YAML file into notebook-level names.
# Missing keys raise KeyError, exactly as a plain dict lookup would.
grid_keys = ("nt", "nr", "dt", "dr")
nt, nr, dt, dr = [params["grid"][key] for key in grid_keys]
k_id, w_id, patch_diameter = [params["stimulus"][key]
                              for key in ("k_id", "w_id", "patch_diameter")]
A_g, a_g, B_g, b_g = [params["ganglion"][key] for key in ("A", "a", "B", "b")]

# Relay-cell kernels: each section provides a weight w, an amplitude A and a width-like parameter a.
w_rg, A_rg, a_rg = [params["relay"]["Krg"][key] for key in ("w", "A", "a")]
w_rig, A_rig, a_rig = [params["relay"]["Krig"][key] for key in ("w", "A", "a")]
w_rc_ex, A_rc_ex, a_rc_ex = [params["relay"]["Krc_ex"][key] for key in ("w", "A", "a")]
w_rc_in, A_rc_in, a_rc_in = [params["relay"]["Krc_in"][key] for key in ("w", "A", "a")]

# Mixed feedback shares one weight list but has separate inhibitory/excitatory kernels.
w_rc_mix = params["relay"]["Krc_mix"]["w"]
A_rc_mix_in, a_rc_mix_in = [params["relay"]["Krc_mix"]["Krc_in"][key] for key in ("A", "a")]
A_rc_mix_ex, a_rc_mix_ex = [params["relay"]["Krc_mix"]["Krc_ex"][key] for key in ("A", "a")]

# Results, keyed first by feedback type and then by patch diameter (float).
wavenumber_tuning = {kind: {} for kind in ("fb_ex", "fb_in", "fb_mix")}

## Feedback excitation

In [2]:
for d in patch_diameter:
    # One tuning curve per (spatial grid exponent, excitatory feedback weight);
    # the last axis runs over the wavenumbers selected by k_id.
    tuning_curve = np.zeros([len(nr), len(w_rc_ex), len(k_id)])

    for row, n in enumerate(nr):
        for col, weight in enumerate(w_rc_ex):
            # A fresh network is built for every grid size / weight pair
            # (and again for every patch diameter).
            network = create_spatial_network(
                nt=nt, nr=n, dt=dt, dr=dr,
                A_g=A_g, a_g=a_g, B_g=B_g, b_g=b_g,
                w_rg=w_rg, A_rg=A_rg, a_rg=a_rg,
                w_rig=w_rig, A_rig=A_rig, a_rig=a_rig,
                w_rc_ex=weight, A_rc_ex=A_rc_ex, a_rc_ex=a_rc_ex,
            )
            integrator = network.integrator
            response = spatiotemporal_wavenumber_tuning(
                network=network,
                angular_freq=integrator.temporal_angular_freqs[int(w_id)],
                wavenumber=integrator.spatial_angular_freqs[k_id.astype(int)],
                patch_diameter=d,
            )
            # Only row 0 is kept -- presumably the sole row, since a single
            # temporal frequency is requested.
            tuning_curve[row, col, :] = response[0, :]

    wavenumber_tuning["fb_ex"][float(d.magnitude)] = tuning_curve



## Feedback inhibition

In [3]:
for d in patch_diameter:
    # One tuning curve per (spatial grid exponent, inhibitory feedback weight);
    # the last axis runs over the wavenumbers selected by k_id.
    tuning_curve = np.zeros([len(nr), len(w_rc_in), len(k_id)])

    for row, n in enumerate(nr):
        for col, weight in enumerate(w_rc_in):
            # A fresh network is built for every grid size / weight pair
            # (and again for every patch diameter).
            network = create_spatial_network(
                nt=nt, nr=n, dt=dt, dr=dr,
                A_g=A_g, a_g=a_g, B_g=B_g, b_g=b_g,
                w_rg=w_rg, A_rg=A_rg, a_rg=a_rg,
                w_rig=w_rig, A_rig=A_rig, a_rig=a_rig,
                w_rc_in=weight, A_rc_in=A_rc_in, a_rc_in=a_rc_in,
            )
            integrator = network.integrator
            response = spatiotemporal_wavenumber_tuning(
                network=network,
                angular_freq=integrator.temporal_angular_freqs[int(w_id)],
                wavenumber=integrator.spatial_angular_freqs[k_id.astype(int)],
                patch_diameter=d,
            )
            # Only row 0 is kept -- presumably the sole row, since a single
            # temporal frequency is requested.
            tuning_curve[row, col, :] = response[0, :]

    wavenumber_tuning["fb_in"][float(d.magnitude)] = tuning_curve



## Mixed excitatory and inhibitory feedback

In [4]:
for d in patch_diameter:
    # One tuning curve per (spatial grid exponent, mixed feedback weight);
    # the last axis runs over the wavenumbers selected by k_id.
    tuning_curve = np.zeros([len(nr), len(w_rc_mix), len(k_id)])

    for row, n in enumerate(nr):
        for col, weight in enumerate(w_rc_mix):
            # NOTE(review): the same weight is passed to both the inhibitory
            # and the excitatory feedback loop; presumably the sign/size of
            # each contribution comes from A_rc_mix_in / A_rc_mix_ex -- confirm
            # against the parameter file.
            network = create_spatial_network(
                nt=nt, nr=n, dt=dt, dr=dr,
                A_g=A_g, a_g=a_g, B_g=B_g, b_g=b_g,
                w_rg=w_rg, A_rg=A_rg, a_rg=a_rg,
                w_rig=w_rig, A_rig=A_rig, a_rig=a_rig,
                w_rc_in=weight, A_rc_in=A_rc_mix_in, a_rc_in=a_rc_mix_in,
                w_rc_ex=weight, A_rc_ex=A_rc_mix_ex, a_rc_ex=a_rc_mix_ex,
            )
            integrator = network.integrator
            response = spatiotemporal_wavenumber_tuning(
                network=network,
                angular_freq=integrator.temporal_angular_freqs[int(w_id)],
                wavenumber=integrator.spatial_angular_freqs[k_id.astype(int)],
                patch_diameter=d,
            )
            # Only row 0 is kept -- presumably the sole row, since a single
            # temporal frequency is requested.
            tuning_curve[row, col, :] = response[0, :]

    wavenumber_tuning["fb_mix"][float(d.magnitude)] = tuning_curve



-------------
## Convergence plots

In [5]:
def map_wavenumber(tuning, j, nr, base_index=4):
    """Extract one convergence series from a tuning array.

    For the i-th spatial grid exponent ``nr[i]`` this returns
    ``tuning[i, j, base_index * 2**(nr[i] - min(nr))]``: the sampled index is
    doubled for every doubling of the number of spatial points ``2**nr``.
    Presumably this keeps the selection at the same physical wavenumber,
    because ``dr`` is held fixed while ``nr`` varies -- TODO confirm.

    NOTE(review): the last axis of ``tuning`` is indexed by *position* in the
    ``k_id`` array used to build it, so the mapping only refers to frequency
    indices if ``k_id`` is a contiguous range starting at 0 -- verify against
    the parameter file.

    Parameters
    ----------
    tuning : ndarray, shape (len(nr), n_weights, n_wavenumbers)
        Tuning curves as stored in ``wavenumber_tuning[key][d]``.
    j : int
        Index along the feedback-weight axis.
    nr : array_like of int
        Spatial grid exponents (``N = 2**nr`` points).
    base_index : int, optional
        Wavenumber index on the coarsest grid (``min(nr)``). Defaults to 4,
        the value that was previously hard-coded here.

    Returns
    -------
    ndarray of float64, shape (len(nr),)
        The selected tuning value for each grid size.
    """
    nr = np.asarray(nr)
    # Scale the base index by each grid's refinement factor relative to the coarsest grid.
    wavenumber_idx = (base_index * 2**(nr - nr.min())).astype(int)
    rows = np.arange(len(nr))
    # Fancy indexing pairs row i with wavenumber_idx[i]; cast to float to keep
    # the float64 result of the former element-by-element loop.
    return np.asarray(tuning[rows, j, wavenumber_idx], dtype=float)


titles = ["Excitatory feedback", "Inhibitory feedback", "Mixed feedback"]
labels = [r"$w^\mathrm{ex}_{\mathrm{RCR}}$", r"$|w^\mathrm{in}_{\mathrm{RCR}}|$",
          r"$w^\mathrm{mix}_{\mathrm{RCR}}$"]

# One figure row per feedback type: (key in wavenumber_tuning, weights, row
# title, weight label, whether the label denotes the magnitude |w|). Listing
# the keys explicitly avoids relying on dict insertion order.
feedback_rows = [
    ("fb_ex", w_rc_ex, titles[0], labels[0], False),
    ("fb_in", w_rc_in, titles[1], labels[1], True),
    ("fb_mix", w_rc_mix, titles[2], labels[2], False),
]
# Hand-tuned vertical positions (figure coordinates) of the row titles.
row_title_y = [0.99, 0.68, 0.36]

# The two patch diameters (deg) compared in each panel: the first on the left
# y-axis (colour C0), the second on a twin right y-axis (C1). Each must be a
# key of wavenumber_tuning[...], i.e. one of the stimulus patch_diameter values.
plot_diameters = (1.5, 10.0)
d_left, d_right = plot_diameters

n_points = 2**nr  # number of spatial points N for each grid exponent
n_cols = max(len(row[1]) for row in feedback_rows)


def _set_log2_xscale(ax):
    """Put ``ax`` on a base-2 logarithmic x-scale across Matplotlib versions."""
    try:
        ax.set_xscale("log", base=2)   # Matplotlib >= 3.3 (`basex` was removed in 3.5)
    except (TypeError, ValueError):
        ax.set_xscale("log", basex=2)  # Matplotlib < 3.3 only accepts `basex`


fig, axarr = plt.subplots(len(feedback_rows), n_cols, figsize=(9, 6),
                          sharex="all", squeeze=False)
for y, row in zip(row_title_y, feedback_rows):
    fig.text(0.5, y, row[2], ha="center", va="center", fontsize=16)

for i, (key, w_rc, _, label, show_abs) in enumerate(feedback_rows):
    row_axes = axarr[i, :]

    for j, w in enumerate(w_rc):
        ax = row_axes[j]
        # Print |w| when the label reads |w|, so notation and value agree.
        ax.set_title(label + "={}".format(abs(w) if show_abs else w))
        ax.plot(n_points, map_wavenumber(wavenumber_tuning[key][d_left], j, nr), "-o")
        ax.set_ylabel("d={:g} deg".format(d_left), color="C0")
        _set_log2_xscale(ax)

        ax_twin = ax.twinx()
        ax_twin.plot(n_points, map_wavenumber(wavenumber_tuning[key][d_right], j, nr), "-oC1")
        ax_twin.set_ylabel("d={:g} deg".format(d_right), color="C1")
        _set_log2_xscale(ax_twin)

    # Hide panels left over when this row has fewer weights than n_cols.
    for ax in row_axes[len(w_rc):]:
        ax.set_visible(False)

for ax in axarr[-1, :]:
    ax.set_xlabel(r"Number of spatial points, $N$")
    ax.set_xticks(n_points)

# tight_layout() recomputes subplot spacing, so an earlier
# plt.subplots_adjust(hspace=...) would be discarded. Request room for the
# top row title (rect) and the between-row titles (h_pad) here instead.
fig.tight_layout(rect=(0, 0, 1, 0.95), h_pad=3.0)

<IPython.core.display.Javascript object>