In [None]:
from time import time
from datetime import datetime
from contextlib import contextmanager
from collections import namedtuple
import numpy as np
import matplotlib.pyplot as plt
import nest
import world
from world_populations import Planner, Cortex

from cerebellum import MF_number, define_models, \
        create_forward_cerebellum, create_inverse_cerebellum
import trajectories

In [None]:
# Container for the parts of a full model.
# NOTE(review): not used anywhere below; create_brain returns a plain tuple.
Brain = namedtuple("Brain", ["planner", "cortex", "forward", "inverse"])

# Load the custom NEST extension modules (presumably providing the
# cerebellar neuron/synapse models used by `define_models`).
for module_name in ("cerebmodule", "extracerebmodule"):
    nest.Install(module_name)

# Duration of one trial; passed to nest.Simulate (NEST time unit: ms).
trial_len = 300

In [None]:
def create_brain(prism):
    """Build the planner, the cortex and both cerebellar models, and wire them.

    Parameters
    ----------
    prism : number
        Prism value passed to ``trajectories.save_file`` and to the Planner.

    Returns
    -------
    tuple
        ``(cortex, cereb_for, cereb_inv)``. The planner itself is not returned.
    """
    trajectories.save_file(prism, trial_len)
    define_models()

    # Creation order kept as before: inverse first, then forward.
    inverse_model = create_inverse_cerebellum()
    forward_model = create_forward_cerebellum()

    planner = Planner(MF_number, prism)
    cortex = Cortex(MF_number)
    planner.connect(cortex)

    # Forward model:
    # - the cortical motor command reaches its mossy fibres (efference copy)
    # - its DCN output is fed back into the cortex, the "plus" population
    #   with weight +1 and the "minus" population with weight -1
    # - it also receives a sensory error signal (set per trial via its IO)
    cortex.connect(forward_model.mf)
    one_input_each = {"rule": "fixed_indegree", "indegree": 1}
    dcn_sides = (
        (forward_model.dcn.plus, 1.0),
        (forward_model.dcn.minus, -1.0),
    )
    for dcn_side, sign in dcn_sides:
        nest.Connect(dcn_side.pop, cortex.pop, one_input_each, {'weight': sign})

    # Inverse model:
    # - sensory input from the planner via its mossy fibres
    # - its DCN output is not connected here; the closed-loop cell reads it
    #   with ``cereb_inv.dcn.integrate`` and adds it to the cortical output
    # - it also receives an error signal (set per trial via its IO)
    planner.connect(inverse_model.mf)

    return cortex, forward_model, inverse_model

In [None]:
def create_cortex(prism=0):
    """Build a planner -> cortex network without any cerebellum.

    Parameters
    ----------
    prism : number, optional
        Initial prism handed to the Planner (default 0). It can be changed
        later with ``planner.set_prism``.

    Returns
    -------
    tuple
        ``(planner, cortex)``.
    """
    # NOTE(review): the trajectory file is always written for prism 0 here,
    # whereas create_brain passes `prism`. Confirm the asymmetry is intended.
    trajectories.save_file(0, trial_len)
    define_models()

    planner = Planner(MF_number, prism, baseline_rate=10.0, gain_rate=1.0)
    cortex = Cortex(
        MF_number,
        rbf_sdev=MF_number / 64,
        baseline_rate=0.0,
        gain_rate=4.0,
    )
    planner.connect(cortex)

    return planner, cortex

In [None]:
# Quick check: drive the cortex-only network at two prism values, one trial
# each, then inspect the activity of cortex joint 1.
nest.ResetKernel()

planner, cortex = create_cortex()
xs = []

for i, prism in enumerate([15, 40]):
    print("\rPrism: %3d" % prism, end="")
    planner.set_prism(prism)
    nest.Simulate(trial_len)
    # NOTE(review): `x` is not used afterwards in this cell.
    x = cortex.integrate(trial_i=i)

print()
joint_1 = cortex.joints[1]
print("Rate:", joint_1.get_rate())
joint_1.plot_spikes()

In [None]:
# Torque, velocity and position traces for the first two recorded trials:
# one row per trial, one column per quantity. `cortex.integrate(trial)` is
# called before reading the traces, as they are read from the cortex
# afterwards (presumably refreshed by integrate — confirm).
fig, ax = plt.subplots(2, 3)

for row, trial in enumerate((0, 1)):
    cortex.integrate(trial)
    panels = (
        ("Torque", cortex.torques),
        ("Velocity", cortex.vel),
        ("Position", cortex.pos),
    )
    for col, (title, series) in enumerate(panels):
        ax[row, col].plot(series)
        ax[row, col].set_title("%s (trial %d)" % (title, trial))

fig.tight_layout()
plt.show()

In [None]:
# Get reference x: measure the mean integrated cortex output `x` at prism 0
# and at prism 10 in one continuous simulation, then build the error function
# that maps an `x` to an error via world.get_error_function(x_0, x_10).
nest.ResetKernel()

planner, cortex = create_cortex()

N = 6  # trials per prism; the first one of each prism is discarded
assert N > 1, "need at least one trial after the discarded first one"


def _mean_settled_x(prism, first_trial):
    """Run N trials at `prism`; return the mean `x` over all but the first.

    `first_trial` offsets the trial index passed to `cortex.integrate`,
    because trial numbering continues across prism changes in one kernel.
    """
    planner.set_prism(prism)
    trial_xs = []
    for trial_i in range(first_trial, first_trial + N):
        nest.Simulate(trial_len)
        trial_xs.append(cortex.integrate(trial_i=trial_i))
    # Skip the first trial after a prism change (presumably a transient).
    return np.mean(trial_xs[1:])


x_0 = _mean_settled_x(0, first_trial=0)
x_10 = _mean_settled_x(10, first_trial=N)

get_error = world.get_error_function(x_0, x_10)

print("x_0", x_0)
print("x_10", x_10)

In [None]:
# Test open loop error: sweep the prism on the cortex-only network (no
# cerebellum) and record, per prism, the error of the mean `x` plus the
# smallest and largest per-trial error.
from random import shuffle  # only needed if the shuffle below is re-enabled
nest.ResetKernel()
planner, cortex = create_cortex()
errors = []
mins = []
maxs = []
planner_avgs = []

n_trials = 5  # trials per prism value (also reused by the closed-loop cell)
prism_values = list(np.arange(-5, 51, 5))
# prism_values = list(np.arange(30, 51, 5))
# shuffle(prism_values)

for j, prism in enumerate(prism_values):
    print("Prism:", prism)
    planner.set_prism(prism)

    xs = []
    planner_rates = []

    for i in range(n_trials):
        nest.Simulate(trial_len)

        # Trial indices keep counting across prism values.
        x = cortex.integrate(trial_i=i + n_trials * j)
        # Return value unused. NOTE(review): presumably called for its side
        # effect of recording the per-trial rate plotted later — confirm.
        cortex.joints[1].get_per_trial_rate()

        xs.append(x)
        planner_rates.append(planner.get_per_trial_rate())
        print("\rTrial:", i, end='')

    planner_avgs.append(np.mean(planner_rates))
    errors.append(get_error(np.mean(xs)))
    # Map every trial through get_error before taking the extrema. Taking
    # min/max of the raw x values instead mislabels the extrema whenever
    # get_error is not increasing in x.
    trial_errors = [get_error(x_trial) for x_trial in xs]
    mins.append(min(trial_errors))
    maxs.append(max(trial_errors))

    print("\rError:", "%2.1f" % errors[-1], "min: %2.1f" % mins[-1], "max: %2.1f" % maxs[-1])


In [None]:
# Open-loop sweep summary: error of the mean x, min and max per prism value,
# plus the mean planner rate rescaled to a peak of 50 so it fits on the same
# axes. A new name is used so the raw `planner_avgs` are not overwritten.
planner_avgs_scaled = np.asarray(planner_avgs, dtype=float)
peak_rate = planner_avgs_scaled.max() if planner_avgs_scaled.size else 0.0
if peak_rate > 0:  # avoid dividing by zero when the planner never fired
    planner_avgs_scaled = 50 * planner_avgs_scaled / peak_rate

fig, ax = plt.subplots()
ax.plot(prism_values, errors, label="error of mean x")
ax.plot(prism_values, mins, label="min")
ax.plot(prism_values, maxs, label="max")
ax.plot(prism_values, planner_avgs_scaled, label="planner rate (scaled)")
ax.set_xlabel("Prism")
ax.set_ylabel("Error")
ax.legend()

plt.show()

In [None]:
# Per-trial rates recorded so far for cortex joint 1.
joint_1 = cortex.joints[1]
joint_1.plot_per_trial_rates()

In [None]:
def get_weights(pop1, pop2, step=50):
    """Return the weights of a sample of the connections from `pop1` to `pop2`.

    Only every `step`-th element of each population is considered, which
    keeps the per-trial weight logs (and the heat maps built from them) small.

    Parameters
    ----------
    pop1, pop2 : sliceable NEST populations
        Source and target populations.
    step : int, optional
        Subsampling stride applied to both populations. The default of 50 is
        the value that was previously hard-coded.

    Returns
    -------
    The weights returned by ``nest.GetStatus(conns, "weight")``.
    """
    conns = nest.GetConnections(pop1[::step], pop2[::step])
    return nest.GetStatus(conns, "weight")

In [None]:
# Logs filled trial by trial by the closed-loop cell: the error after each
# trial, and the sampled PF-PC weights of the forward / inverse cerebellum.
error_history, weights_for, weights_inv = [], [], []

In [None]:
# Closed-loop configuration. FORWARD, INVERSE and open_loop_error were never
# defined in this notebook (NameError on a fresh kernel), and `prism` leaked
# from the open-loop sweep loop. They are now explicit.
# NOTE(review): defaults chosen here are both models enabled and the last prism
# of the sweep (the value `prism` previously held). Confirm this is the
# intended experiment.
FORWARD = True  # drive, record and report the forward cerebellum
INVERSE = True  # drive the inverse cerebellum and add its DCN output to x
prism = prism_values[-1]

nest.ResetKernel()
cortex, cereb_for, cereb_inv = create_brain(prism)

# The first closed-loop trial uses the open-loop error measured for this
# prism in the sweep above.
open_loop_error = errors[prism_values.index(prism)]
sensory_error = open_loop_error

In [None]:
def _report_cerebellum(label, cereb, weight_log):
    """Print the per-trial rate of each stage of `cereb` and log its PF-PC weights.

    The sampled granule -> Purkinje weights are appended to `weight_log`
    before being summarised, and are also returned.
    """
    print()
    for stage in ("IO", "MF", "GR", "PC", "DCN"):
        rate = getattr(cereb, stage.lower()).get_per_trial_rate()
        print("%s %s: %.1f" % (label, stage, rate))

    sampled = get_weights(cereb.gr.pop, cereb.pc.pop)
    weight_log.append(sampled)
    print("%s PFPC weights:" % label, min(sampled), "to", max(sampled))
    return sampled


for i in range(n_trials):
    # Drive each enabled model's IO with the error of the previous trial
    # (the open-loop error on the very first trial).
    if FORWARD:
        cereb_for.io.set_rate(sensory_error)
    if INVERSE:
        cereb_inv.io.set_rate(sensory_error, trial_i=i)

    print("Simulating")
    nest.Simulate(trial_len)
    print()
    print("Trial ", i + 1)
    print()

    x_cortex = cortex.integrate(trial_i=i)
    x_sum = x_cortex
    if INVERSE:
        # The inverse model's DCN contribution is added to the cortical output.
        x_dcn = cereb_inv.dcn.integrate(trial_i=i)
        x_sum = x_cortex + x_dcn

    sensory_error = get_error(x_sum)
    error_history.append(sensory_error)
    print("Closed loop error %d:" % i, sensory_error)

    if FORWARD:
        weights = _report_cerebellum("Forward", cereb_for, weights_for)
    if INVERSE:
        weights = _report_cerebellum("Inverse", cereb_inv, weights_inv)

In [None]:
# Heat maps of the sampled PF-PC (granule -> Purkinje) weights logged by the
# closed-loop cell: after the transpose, one column per trial and one row per
# sampled synapse. The previous titles said "PC-DCN", which does not match
# get_weights(gr.pop, pc.pop).
# Fixed colour ceiling, kept from the original. The data maximum
# (np.max(weights_for) / np.max(weights_inv)) could be passed instead.
W_VMAX = 30

fig, axs = plt.subplots(1, 2, figsize=(12, 6))

if FORWARD:
    axs[0].set_title("Forward PF-PC weights")
    axs[0].matshow(np.transpose(weights_for), aspect='auto', vmin=0.0, vmax=W_VMAX)
    axs[0].set_xlabel("Trial")
    axs[0].set_ylabel("Sampled synapse")

if INVERSE:
    axs[1].set_title("Inverse PF-PC weights")
    axs[1].matshow(np.transpose(weights_inv), aspect='auto', vmin=0.0, vmax=W_VMAX)
    axs[1].set_xlabel("Trial")
    axs[1].set_ylabel("Sampled synapse")

plt.tight_layout()
plt.show()

# fig.savefig("weights_40.pdf", bbox_inches='tight')
# fig.savefig("weights_40.png", bbox_inches='tight')

In [None]:
fig, axs = plt.subplots(5, figsize=(12, 8))


def _plot_rate_rows(cereb, rows):
    """Draw the MF, IO, PC and DCN per-trial rates of `cereb` on rows[0..3]."""
    for stage_label, row in zip(("MF", "IO", "PC", "DCN"), rows):
        getattr(cereb, stage_label.lower()).plot_per_trial_rates(stage_label, row)


# With both models enabled, their curves share the same four rows.
if FORWARD:
    _plot_rate_rows(cereb_for, axs)
if INVERSE:
    _plot_rate_rows(cereb_inv, axs)

# Last row: closed-loop error per trial.
axs[4].set_ylabel('Error')
axs[4].plot(error_history)
plt.show()

# fig.savefig("rates_40.pdf", bbox_inches='tight')
# fig.savefig("rates_40.png", bbox_inches='tight')

In [None]:
if FORWARD:
    fig, axs = plt.subplots(5, figsize=(12, 8))
    fig.suptitle("Forward")

    # One spike raster per stage of the forward cerebellum.
    forward_stages = (('f MF', 'mf'), ('f IO', 'io'), ('f PC', 'pc'), ('f DCN', 'dcn'))
    for row, (label, stage) in enumerate(forward_stages):
        getattr(cereb_for, stage).plot_spikes(label, axs[row])

    # Last row: closed-loop error per trial.
    axs[4].set_ylabel('Error')
    axs[4].plot(error_history)
    plt.show()

    # fig.savefig("forward_spikes_40.pdf", bbox_inches='tight')
    # fig.savefig("forward_spikes_40.png", bbox_inches='tight')

In [None]:
if INVERSE:
    fig, axs = plt.subplots(5, figsize=(12, 8))
    fig.suptitle("Inverse")

    # One spike raster per stage of the inverse cerebellum.
    inverse_stages = (('i MF', 'mf'), ('i IO', 'io'), ('i PC', 'pc'), ('i DCN', 'dcn'))
    for row, (label, stage) in enumerate(inverse_stages):
        getattr(cereb_inv, stage).plot_spikes(label, axs[row])

    # Last row: closed-loop error per trial.
    axs[4].set_ylabel('Error')
    axs[4].plot(error_history)
    plt.show()

    # fig.savefig("inverse_spikes_40.pdf", bbox_inches='tight')
    # fig.savefig("inverse_spikes_40.png", bbox_inches='tight')