In [None]:
from time import time
from datetime import datetime
from contextlib import contextmanager
from collections import namedtuple
import numpy as np
import matplotlib.pyplot as plt
import nest
import world
from world_populations import Planner, Cortex

from cerebellum import MF_number, define_models, \
        create_forward_cerebellum, create_inverse_cerebellum
import trajectories

In [None]:
# Record type for the four parts of a brain. NOTE(review): not used anywhere
# in this notebook (create_brain returns a plain 3-tuple) — consider removing.
Brain = namedtuple("Brain", ["planner", "cortex", "forward", "inverse"])

# Load the custom NEST extension modules (presumably providing the neuron /
# synapse models registered by define_models — TODO confirm). NEST refuses to
# install a module twice, so this cell is not safe to re-run in one kernel.
for _module_name in ("cerebmodule", "extracerebmodule"):
    nest.Install(_module_name)

# Length of one trial, passed to nest.Simulate (NEST simulation time is in ms).
trial_len = 300

In [None]:
def create_brain(prism):
    """Build the planner/cortex network together with both cerebellum models.

    Steps, in order (NEST assigns node ids in creation order, so the order is
    kept deliberately):

    1. write the trajectory file for ``prism`` (uses the global ``trial_len``);
    2. define the models and create the inverse, then the forward cerebellum;
    3. create the planner and cortex and connect planner -> cortex;
    4. forward model: cortex -> forward mossy fibres (efference copy), and the
       forward DCN ``plus`` / ``minus`` populations -> cortex with weights
       +1.0 / -1.0 using a fixed in-degree of 1;
    5. inverse model: planner -> inverse mossy fibres (sensory input).

    Design intent noted by the author: the forward model should also carry a
    sensory error signal, and the inverse model a motor output to the world and
    a motor error signal; those are not wired here.

    Parameters
    ----------
    prism : float
        Prism value handed to ``trajectories.save_file`` and ``Planner``.

    Returns
    -------
    tuple
        ``(cortex, cereb_for, cereb_inv)``.
    """
    trajectories.save_file(prism, trial_len)

    define_models()
    cereb_inv = create_inverse_cerebellum()
    cereb_for = create_forward_cerebellum()

    planner = Planner(MF_number, prism)
    cortex = Cortex(MF_number)
    planner.connect(cortex)

    # Forward model input: efference copy of the cortical motor command.
    cortex.connect(cereb_for.mf)

    # Forward model output: the DCN plus/minus halves excite/inhibit the cortex.
    single_input = {"rule": "fixed_indegree", "indegree": 1}
    dcn_projections = (
        (cereb_for.dcn.plus, 1.0),
        (cereb_for.dcn.minus, -1.0),
    )
    for dcn_half, sign in dcn_projections:
        nest.Connect(dcn_half.pop, cortex.pop, single_input, {'weight': sign})

    # Inverse model input: the planner's sensory signal.
    planner.connect(cereb_inv.mf)

    return cortex, cereb_for, cereb_inv

In [None]:
# Which cerebellar model(s) receive the error signal and are reported/plotted.
# When INVERSE is on, the inverse model's DCN output is also added to the
# cortical output in the closed loop.
FORWARD, INVERSE = False, True

prism = 20.0   # prism value for the open-loop and closed-loop experiments
n_trials = 10  # number of closed-loop training trials

In [None]:
# Reference outputs: the mean integrated cortical output of a fresh brain with
# prism 0.0 and with prism 10.0. These two anchor points are passed to
# world.get_error_function to build the output -> error mapping.
N_CALIBRATION_TRIALS = 6  # trials simulated per calibration run
N_WARMUP_TRIALS = 1       # leading trials excluded from the average


def mean_trial_output(prism_value, n_trials=N_CALIBRATION_TRIALS,
                      n_warmup=N_WARMUP_TRIALS):
    """Simulate a fresh brain and return its mean per-trial cortical output.

    Resets the NEST kernel, builds the network with ``create_brain``, runs
    ``n_trials`` trials of ``trial_len`` each and integrates the cortex output
    of every trial (including warm-up trials, so ``cortex.integrate`` sees the
    same sequence of ``trial_i`` values as before). Only trials with index
    ``>= n_warmup`` are averaged.

    Raises
    ------
    ValueError
        If ``n_trials <= n_warmup`` — there would be nothing to average and
        ``np.mean([])`` would silently return NaN.
    """
    if n_trials <= n_warmup:
        raise ValueError("n_trials (%d) must exceed n_warmup (%d)"
                         % (n_trials, n_warmup))

    nest.ResetKernel()
    cortex, _, _ = create_brain(prism_value)

    kept_outputs = []
    for trial in range(n_trials):
        nest.Simulate(trial_len)
        trial_output = cortex.integrate(trial_i=trial)
        if trial >= n_warmup:
            kept_outputs.append(trial_output)

    return np.mean(kept_outputs)


x_0 = mean_trial_output(0.0)
x_10 = mean_trial_output(10.0)

get_error = world.get_error_function(x_0, x_10)

In [None]:
# Open-loop error: a fresh brain with the experimental prism and no
# cerebellar input to the IO. The cortex output is integrated for every trial;
# the first (warm-up) trial is left out of the average.
nest.ResetKernel()
cortex, _, _ = create_brain(prism)


def _simulate_trial(trial_idx):
    """Run one trial and return the cortex output integrated over it."""
    nest.Simulate(trial_len)
    return cortex.integrate(trial_i=trial_idx)


trial_outputs = [_simulate_trial(trial_idx) for trial_idx in range(6)]
xs = trial_outputs[1:]

open_loop_error = get_error(np.mean(xs))
print("Open loop error:", open_loop_error)

In [None]:
def get_weights(pop1, pop2, step=50):
    """Return the weights of a sub-sample of the ``pop1 -> pop2`` connections.

    Only every ``step``-th node of each population is considered. This keeps
    the snapshot small enough to store once per trial and plot as a heat map.

    Parameters
    ----------
    pop1, pop2 :
        Source and target populations. They must support stepped slicing
        (``pop[::step]``); presumably NEST node collections — TODO confirm.
    step : int, optional
        Sub-sampling stride applied to both populations (default 50, the
        previously hard-coded value). Must be a positive integer.

    Returns
    -------
    The result of ``nest.GetStatus(conns, "weight")``: one weight per selected
    connection.

    Raises
    ------
    ValueError
        If ``step`` is not positive. A zero step would otherwise fail with an
        obscure slicing error, and a negative one would silently reverse the
        populations.
    """
    if step < 1:
        raise ValueError("step must be a positive integer, got %r" % (step,))
    conns = nest.GetConnections(pop1[::step], pop2[::step])
    return nest.GetStatus(conns, "weight")

In [None]:
# Per-trial logs filled by the closed-loop training loop: the sensory error
# after each trial, and the sampled PFPC weights of each cerebellum model.
error_history, weights_for, weights_inv = [], [], []

In [None]:
# The first closed-loop trial is taught with the open-loop error measured
# above; afterwards the training loop replaces it with each trial's error.
sensory_error = open_loop_error

# Start the closed-loop experiment from a freshly built network.
nest.ResetKernel()
cortex, cereb_for, cereb_inv = create_brain(prism)

In [None]:
def report_model(label, cereb, weight_log):
    """Print per-trial rates and the PFPC weight range for one cerebellum.

    Prints the per-trial rate of the IO, MF, GR, PC and DCN populations, then
    samples the GR -> PC weights with ``get_weights``. The sample is appended
    to ``weight_log`` for plotting after training.

    Parameters
    ----------
    label : str
        Prefix for every printed line ("Forward" or "Inverse").
    cereb :
        Cerebellum object exposing ``io``, ``mf``, ``gr``, ``pc`` and ``dcn``.
    weight_log : list
        Per-trial weight log that the new sample is appended to.
    """
    print()
    for pop_name in ("IO", "MF", "GR", "PC", "DCN"):
        population = getattr(cereb, pop_name.lower())
        print("%s %s: %.1f" % (label, pop_name, population.get_per_trial_rate()))

    weights = get_weights(cereb.gr.pop, cereb.pc.pop)
    weight_log.append(weights)
    print("%s PFPC weights:" % label, min(weights), "to", max(weights))


for i in range(n_trials):
    # Set the IO rate from the previous trial's sensory error.
    # NOTE(review): the forward IO is not given trial_i while the inverse IO
    # is; confirm whether set_rate should receive it in both cases.
    if FORWARD:
        cereb_for.io.set_rate(sensory_error)
    if INVERSE:
        cereb_inv.io.set_rate(sensory_error, trial_i=i)

    print("Simulating")
    nest.Simulate(trial_len)
    print()
    print("Trial ", i+1)
    print()

    x_cortex = cortex.integrate(trial_i=i)

    if INVERSE:
        # The inverse model's DCN output is added to the cortical output.
        x_dcn = cereb_inv.dcn.integrate(trial_i=i)
        x_sum = x_cortex + x_dcn
    else:
        x_sum = x_cortex

    sensory_error = get_error(x_sum)
    error_history.append(sensory_error)
    # 1-based, matching the "Trial" header printed above.
    print("Closed loop error %d:" % (i + 1), sensory_error)

    if FORWARD:
        report_model("Forward", cereb_for, weights_for)
    if INVERSE:
        report_model("Inverse", cereb_inv, weights_inv)

In [None]:
# Heat maps of the sampled GR -> PC (PFPC) weights recorded after every
# training trial: one column per trial, one row per sampled connection.
# (These are the weights logged by get_weights(gr.pop, pc.pop) in the
# training loop, not PC-DCN weights.)

# Fixed colour-scale ceiling kept from the original plot (presumably so that
# figures from different runs share a scale).
WEIGHT_VMAX = 30

fig, axs = plt.subplots(1, 2, figsize=(12, 6))

weight_panels = (
    (FORWARD, axs[0], "Forward PFPC weights", weights_for),
    (INVERSE, axs[1], "Inverse PFPC weights", weights_inv),
)
for enabled, ax, title, weight_log in weight_panels:
    # Skip disabled models, and empty logs (e.g. n_trials == 0), which would
    # give matshow a 1-D array.
    if not enabled or len(weight_log) == 0:
        continue
    ax.set_title(title)
    ax.matshow(np.transpose(weight_log), aspect='auto',
               vmin=0.0, vmax=WEIGHT_VMAX)
    ax.set_xlabel("Trial")
    ax.set_ylabel("Sampled connection")

plt.tight_layout()
plt.show()

# NOTE(review): the "_40" suffix does not match prism = 20.0 set in the
# configuration cell; confirm which run these file names should describe.
fig.savefig("weights_40.pdf", bbox_inches='tight')
fig.savefig("weights_40.png", bbox_inches='tight')

In [None]:
# Per-trial firing rates of the MF, IO, PC and DCN populations for every
# enabled model (overlaid on shared axes), plus the closed-loop error.
fig, axs = plt.subplots(5, figsize=(12,8))

enabled_models = []
if FORWARD:
    enabled_models.append(cereb_for)
if INVERSE:
    enabled_models.append(cereb_inv)

for model in enabled_models:
    rate_panels = (
        ('MF', model.mf),
        ('IO', model.io),
        ('PC', model.pc),
        ('DCN', model.dcn),
    )
    for ax, (label, population) in zip(axs, rate_panels):
        population.plot_per_trial_rates(label, ax)

error_ax = axs[4]
error_ax.set_ylabel('Error')
error_ax.plot(error_history)
plt.show()

for out_path in ("rates_40.pdf", "rates_40.png"):
    fig.savefig(out_path, bbox_inches='tight')

In [None]:
# Spike rasters of the forward model's MF, IO, PC and DCN populations, with
# the closed-loop error in the bottom panel.
if FORWARD:
    fig, axs = plt.subplots(5, figsize=(12,8))
    fig.suptitle("Forward")

    spike_panels = (
        ('f MF', cereb_for.mf),
        ('f IO', cereb_for.io),
        ('f PC', cereb_for.pc),
        ('f DCN', cereb_for.dcn),
    )
    for (panel_title, population), ax in zip(spike_panels, axs):
        population.plot_spikes(panel_title, ax)

    error_ax = axs[-1]
    error_ax.set_ylabel('Error')
    error_ax.plot(error_history)
    plt.show()

    for out_path in ("forward_spikes_40.pdf", "forward_spikes_40.png"):
        fig.savefig(out_path, bbox_inches='tight')

In [None]:
# Spike rasters of the inverse model's MF, IO, PC and DCN populations, with
# the closed-loop error in the bottom panel.
if INVERSE:
    fig, axs = plt.subplots(5, figsize=(12, 8))
    fig.suptitle("Inverse")

    spike_panels = (
        ('i MF', cereb_inv.mf),
        ('i IO', cereb_inv.io),
        ('i PC', cereb_inv.pc),
        ('i DCN', cereb_inv.dcn),
    )
    for (panel_title, population), ax in zip(spike_panels, axs):
        population.plot_spikes(panel_title, ax)

    error_ax = axs[-1]
    error_ax.set_ylabel('Error')
    error_ax.plot(error_history)
    plt.show()

    for out_path in ("inverse_spikes_40.pdf", "inverse_spikes_40.png"):
        fig.savefig(out_path, bbox_inches='tight')