In [2]:
# Load IPython's autoreload extension and set it to mode 2, so edits to imported
# modules (e.g. the `pim` package imported below) are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Select the interactive "notebook" matplotlib backend for the figures below.
%matplotlib notebook

In [3]:
import json
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import cm


import numpy as np
from loguru import logger
logger.remove()

from pim.models import rate
from pim.models.weights import cpu1a_pontine_output, cpu1b_pontine_output, motor_output, pontine_output
from pim.models.weights.shift import motor_output_theoretical, W_CPU4_CPU1a, W_CPU4_CPU1b
from pim.models.constants import *
from pim.network import Network, RecurrentForwardNetwork, InputLayer, Layer, FunctionLayer, Output, WeightedSynapse

In [4]:
# Sample N_COLUMNS evenly spaced angles over [0, 2*pi); the endpoint is excluded.
N_COLUMNS = 8
x = np.linspace(0, 2 * np.pi, N_COLUMNS, endpoint=False)

A = 0.1          # amplitude of the example `home` bump
d = -np.pi / 2   # phase offset of the example `head` bump


def _rectify(signal):
    """Clamp negative samples of `signal` to zero."""
    return np.maximum(0, signal)


# Two half-wave rectified cosines sampled at the column angles.
head = _rectify(np.cos(x + d))
home = _rectify(A * np.cos(x))

plt.plot(head)
plt.plot(home)

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7f692ae88640>]

In [5]:
# Module-level settings read by getNetwork() at the time it is called.
# Passed to the CPU4 layer as `background_activity`.
beta = 0.0
# Passed as `noise` to the CPU4 layer and to every *_output factory in getNetwork().
noise = 0.0
# Passed to the CPU4 layer as `gain`.
pfn_weight_factor = 1

def getNetwork(home):
    """Assemble the RecurrentForwardNetwork used for the steering sweep.

    Layers are registered in a fixed order: flow, heading, TB1, TN1, TN2, CPU4,
    memory, Pontine, theory, CPU1a, CPU1b, motor.

    Args:
        home: array that is concatenated with itself (``np.hstack([home, home])``)
            and multiplied elementwise with the CPU4 output to form the
            "memory" layer.

    Returns:
        A ``RecurrentForwardNetwork`` built from the layers above.

    Reads the globals ``x``, ``noise``, ``beta`` and ``pfn_weight_factor``; ``x``
    is looked up each time the TB1 function is evaluated, not when the network
    is built.
    """

    def tb1_response(i):
        # Cosine over the column angles, shifted by the current heading value.
        return np.cos(x + i[0])

    def tn1_response(i):
        # Constant output; the flow input is ignored.
        return np.array([0., 0.])

    def tn2_response(i):
        # Constant output; the flow input is ignored.
        return np.array([0.4, 0.4])

    def memory_response(i):
        # Scale the CPU4 output elementwise by the doubled home vector.
        return i[0] * np.hstack([home, home])

    layers = {}

    # External inputs.
    layers["flow"] = InputLayer(initial=np.zeros(2))
    layers["heading"] = InputLayer()

    # Function layers fed directly by the inputs.
    layers["TB1"] = FunctionLayer(
        inputs=["heading"],
        function=tb1_response,
    )
    layers["TN1"] = FunctionLayer(
        inputs=["flow"],
        function=tn1_response,
        initial=np.zeros(N_TN1),
    )
    layers["TN2"] = FunctionLayer(
        inputs=["flow"],
        function=tn2_response,
        initial=np.zeros(N_TN2),
    )

    layers["CPU4"] = rate.MemorylessCPU4Layer(
        "TB1", "TN1", "TN2",
        rate.W_TN_CPU4,
        rate.W_TB1_CPU4,
        gain=pfn_weight_factor,
        slope=cpu4_slope_tuned,
        bias=cpu4_bias_tuned,
        noise=noise,
        background_activity=beta,
    )
    layers["memory"] = FunctionLayer(
        inputs=["CPU4"],
        function=memory_response,
    )
    layers["Pontine"] = FunctionLayer(
        inputs=["memory"],
        function=pontine_output(noise),
        initial=np.zeros(N_Pontine),
    )
    layers["theory"] = FunctionLayer(
        inputs=["memory", "CPU4", "TB1", "Pontine"],
        function=motor_output_theoretical(noise),
    )

    # CPU1a / CPU1b share the same slope and bias; only the CPU4 weights differ.
    layers["CPU1a"] = FunctionLayer(
        inputs=[WeightedSynapse("CPU4", W_CPU4_CPU1a), "memory", "Pontine"],
        function=cpu1a_pontine_output(
            noise,
            cpu1_pontine_slope_tuned,
            cpu1_pontine_bias_tuned,
        ),
        initial=np.zeros(N_CPU1A),
    )
    layers["CPU1b"] = FunctionLayer(
        inputs=[WeightedSynapse("CPU4", W_CPU4_CPU1b), "memory", "Pontine"],
        function=cpu1b_pontine_output(
            noise,
            cpu1_pontine_slope_tuned,
            cpu1_pontine_bias_tuned,
        ),
        initial=np.zeros(N_CPU1B),
    )

    layers["motor"] = FunctionLayer(
        inputs=["CPU1a", "CPU1b"],
        function=motor_output(noise),
    )

    return RecurrentForwardNetwork(layers)

In [6]:
# Sweep grid: home-vector amplitudes in [0, 1) and heading values in [-pi, pi),
# both in steps of 0.01.
amplitudes = np.arange(0, 1, 0.01)
deltas = np.arange(-np.pi, np.pi, 0.01)


def _negated_motor(network, delta):
    """Set heading to `delta` (flow to 0), advance one step, return -motor."""
    network.layers["flow"].set(0)
    network.layers["heading"].set(np.array([delta]))
    network.step(1)
    return -network.output("motor")


# results[i][j] is the negated motor output for amplitudes[i] and deltas[j].
results = []
for a in amplitudes:
    # One network per amplitude; it is reused for every heading value.
    home = np.maximum(0, a * np.cos(x))
    network = getNetwork(home)
    # `motors` is left holding the final amplitude's curve once the loop ends.
    motors = [_negated_motor(network, delta) for delta in deltas]
    results.append(motors)

In [7]:
# Cross-section of the sweep at the largest amplitude. Read the curve from
# `results` explicitly rather than relying on whatever loop variable happened to
# be left over from the sweep cell.
fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(deltas, results[-1])
ax.set_xlabel("heading delta (rad)")
ax.set_ylabel("-motor output")
ax.set_title(f"Response at home amplitude {amplitudes[-1]:.2f}")
ax.set_ylim([-2, 2])

<IPython.core.display.Javascript object>

(-2.0, 2.0)

In [8]:
# Surface of the negated motor output over the (heading delta, home amplitude) grid.
fig = plt.figure(figsize=(10,10))
ax = plt.axes(projection='3d')

# The grid gets its own names: `x` holds the column angles that getNetwork's TB1
# layer and the sweep cell read, so it must not be rebound to the meshgrid here
# (doing so breaks any re-run of those cells after this one).
delta_grid, amplitude_grid = np.meshgrid(deltas, amplitudes)
motor_surface = np.array(results)  # rows follow amplitudes, columns follow deltas

ax.set_xlabel("heading delta (rad)")
ax.set_ylabel("home amplitude")
ax.set_zlabel("-motor output")

ax.plot_surface(delta_grid, amplitude_grid, motor_surface, cmap=cm.coolwarm)

# Optional rotating view, kept disabled. NOTE(review): `animation` and `HTML` are
# not imported in the import cell shown above; import them before enabling this.
# def animate(frame):
#     ax.view_init(30, frame/4)
#     plt.pause(.001)
#     return fig

# anim = animation.FuncAnimation(fig, animate, frames=600, interval=10)
# HTML(anim.to_html5_video())

<IPython.core.display.Javascript object>

<mpl_toolkits.mplot3d.art3d.Poly3DCollection at 0x7f6928bd2470>