In [24]:
import torch
import matplotlib.pyplot as plt
from utils import *
from ControlledLayer import ControlledLayer

In [25]:
controller_dim = 1
layer = ControlledLayer(3, 1, controller_dim=controller_dim, mode="rate", leak=0.)

def control_neuron_rate(
    input_rate,
    target_rate,
    control_target_rate,
    timesteps,
    C_precision=0.01,
):
    
    controller_effect = []

    # noise = np.random.randn(400)*neuron_noise
    # idx_noise = 0
    input_rate = torch.tensor(input_rate).float()
    layer.reset()

    with torch.no_grad():
        control_input = torch.zeros(controller_dim)
        for _ in range(timesteps * 2 - 1):
            output_rate = layer(input_rate, torch.tensor([control_input]).float()).numpy()
        outputs = [output_rate]

        while abs(output_rate - target_rate) > C_precision:
            output_rate = layer(input_rate, torch.tensor([control_input]).float()).numpy()
            control_input = control_input + 0.1 * (control_target_rate - output_rate)
            
            outputs.append(output_rate)
            controller_effect.append(output_rate - outputs[0])
            # We can also simulate neurons with their own dynamics (but it takes longer)
            # output, v_0 = runNeuron_rate(new_input*np.ones(timesteps), init_v=v_0)
            # output_rate = output[-1]
            # outputs.extend(output)
        return outputs, controller_effect

In [26]:
v_th = 1
target_rates = [-1, 2]
learn_rate = 0.01
regularizer_rate = 0.0  # 0.00005#
C_precision = 0.01
data_noise = 0.1
timesteps = 2
initial_steps_plot = 5
neuron_noise = 0.

plot_path = "./plots/"
plot = True

total_points = 6000
train_points = 5000
w_0 = np.array([-5., 5., -5.]) 
X, y = learnAssociationTask(total_points, data_noise=data_noise)
dynamic_plot_idxs = [1, 501, 2001, 9001]


In [27]:
layer.ff.weight.data = torch.tensor(w_0).float()
w_evol = []
control_evol = []
FF_output_evol = []
time_to_targ_evol = []
DW_DH_list = []
DW_STDP_list = []
list_output_dynamics = []
list_controller_dynamics = []

count_1 = 0
next_dyn_plot_idx = 0


# loop over datapoints
for idx in range(len(y)):
    print(idx)
    target_rate = y[idx]  # output label
    control_target_rate = target_rates[target_rate]

    # FORWARD, with controller controlling
    output, input_C = control_neuron_rate(X[idx, :], target_rate, control_target_rate, timesteps=timesteps)
    R = len(output) - 1  # number of timesteps the controller had to act?

    if R > 0:  # avoids learning if the feedback is already good
        presynaptic_rates = np.vstack([sigmoid(X[idx, :])] * R)
        Dw_DH = update_weights_rates(np.array(output[:-1])[:, 0], presynaptic_rates)
        Dw_STDP = update_weights_poisson(output[:-1], presynaptic_rates)
        print("Weight before", layer.ff.weight.data)
        layer.ff.weight.data += torch.tensor(Dw_STDP).float() - regularizer_rate * layer.ff.weight.data
        print("Update", torch.tensor(Dw_STDP).float())
        print("Weight after", layer.ff.weight.data)

        DW_DH_list.append(Dw_DH)
        DW_STDP_list.append(Dw_STDP)

    count_1 += y[idx]
    if count_1 == dynamic_plot_idxs[next_dyn_plot_idx]:
        list_output_dynamics.append(output[1:])
        list_controller_dynamics.append(input_C)
        next_dyn_plot_idx = (next_dyn_plot_idx + 1) % len(dynamic_plot_idxs)

    w_evol.append(layer.ff.weight.data.numpy())
    FF_output_evol.append(output[0])
    time_to_targ_evol.append(len(input_C))
    if len(input_C) > 1:
        control_evol.append(input_C[-1])
    else:
        control_evol.append(0)

print("Output = 1 was shown " + str(count_1) + " times")

if plot:
    plt.show()

    # ax = plt.axes()
    # ax.set_facecolor("white")
    plt.plot(w_evol)
    plt.legend(["$w_{A->C}$", "$w_{B->C}$", "$w_{bias}$"])
    plt.xlabel("Example")
    plt.ylabel("Weights")
    plt.savefig(
        plot_path + "WeightEvolution.eps", bbox_inches="tight", format="eps"
    )

    plt.show()

    # Benni: Fig 2, panel 1. Normalized to have the maximum at 1
    max_FF = max(FF_output_evol)
    max_C = max(control_evol)
    max_T = max(time_to_targ_evol)
    control_evol_norm = [c / max_C for (c, s) in zip(control_evol, y) if s == 1]
    FF_output_evol_norm = [
        (1 - f) ** 2 for (f, s) in zip(FF_output_evol, y) if s == 1
    ]
    time_to_output_evol_norm = [
        t / max_T for (t, s) in zip(time_to_targ_evol, y) if s == 1
    ]
    plt.plot(control_evol_norm, label="Feedback")
    plt.plot(FF_output_evol_norm, label="MSE Error")
    plt.plot(time_to_output_evol_norm, label="Time to target output")
    plt.xlabel("Example")
    plt.ylabel("Loss")
    plt.title("Loss functions for class 1")
    plt.legend()
    plt.savefig(
        plot_path + "OptimizationFunctions.eps", bbox_inches="tight", format="eps"
    )
    plt.show()

    # Benni: Fig 4 A
    for i in range(len(list_output_dynamics)):
        str_dynamics = "Dynamics at example " + str(dynamic_plot_idxs[i] - 1)
        dynamics = list_output_dynamics[i]
        # init_dynamics, v = runNeuron_rate([0]*initial_steps_plot)
        init_dynamics = [
            dynamics[0] + n * neuron_noise / 2
            for n in np.random.randn(initial_steps_plot)
        ]
        list_to_plot = init_dynamics + dynamics
        plt.plot(list_to_plot, label=str_dynamics)
    # plt.plot(output, label="Dynamics at last example")
    plt.ylim([0.0, 1])
    plt.xlabel("Time")
    plt.ylabel("Output rate")
    plt.legend()
    plt.savefig(
        plot_path + "TemporalDynamics.eps", bbox_inches="tight", format="eps"
    )
    plt.show()

    for i in range(len(list_output_dynamics)):
        str_dynamics = "Feedback at example " + str(dynamic_plot_idxs[i] - 1)
        init_ctr = list_controller_dynamics[i]
        init_control = [
            init_ctr[0] + n * neuron_noise / 2
            for n in np.random.randn(initial_steps_plot)
        ]
        list_to_plot = init_control + list_controller_dynamics[i]
        plt.plot(list_to_plot, label=str_dynamics)

    # plt.plot(input_C, label="Feedback at last example")
    plt.xlabel("Time")
    plt.ylabel("Feedback strength")
    plt.legend()
    plt.show()
    plt.savefig(
        plot_path + "TemporalFeedback.eps", bbox_inches="tight", format="eps"
    )
    # STDP vs dendritic error update
    L_errors = int(
        len(DW_DH_list) / 2
    )  # when plotting errors after learning we get noise
    plt.scatter(DW_DH_list[:L_errors], DW_STDP_list[:L_errors])
    plt.xlabel("STDP weight update")
    plt.ylabel("Dendritic error update")
    plt.savefig(plot_path + "STDP_vs_DH.eps", bbox_inches="tight", format="eps")

    plt.show()


0


KeyboardInterrupt: 

In [None]:
w_evol

[array([ 4.0487480e+00,  2.1551768e-03, -2.0124655e+00], dtype=float32),
 array([ 4.0487480e+00,  2.1551768e-03, -2.0124655e+00], dtype=float32),
 array([ 4.0487480e+00,  2.1551768e-03, -2.0124655e+00], dtype=float32),
 array([ 4.0487480e+00,  2.1551768e-03, -2.0124655e+00], dtype=float32),
 array([ 4.0487480e+00,  2.1551768e-03, -2.0124655e+00], dtype=float32),
 array([ 4.0487480e+00,  2.1551768e-03, -2.0124655e+00], dtype=float32),
 array([ 4.0487480e+00,  2.1551768e-03, -2.0124655e+00], dtype=float32),
 array([ 4.0487480e+00,  2.1551768e-03, -2.0124655e+00], dtype=float32),
 array([ 4.0487480e+00,  2.1551768e-03, -2.0124655e+00], dtype=float32),
 array([ 4.0487480e+00,  2.1551768e-03, -2.0124655e+00], dtype=float32),
 array([ 4.0487480e+00,  2.1551768e-03, -2.0124655e+00], dtype=float32),
 array([ 4.0487480e+00,  2.1551768e-03, -2.0124655e+00], dtype=float32),
 array([ 4.0487480e+00,  2.1551768e-03, -2.0124655e+00], dtype=float32),
 array([ 4.0487480e+00,  2.1551768e-03, -2.0124655e

In [None]:
DW_STDP_list

[array([-0.00946133, -0.0071344 , -0.00045121]),
 array([0., 0., 0.]),
 array([0., 0., 0.]),
 array([0., 0., 0.]),
 array([ 0.00246462, -0.00490338, -0.01523501]),
 array([0., 0., 0.]),
 array([0.00713431, 0.00793741, 0.01520631]),
 array([0., 0., 0.]),
 array([-0.02371478, -0.01206277, -0.01098224]),
 array([0.00927471, 0.00899242, 0.01375398]),
 array([-0.03399891, -0.02099746, -0.02621896]),
 array([0., 0., 0.]),
 array([0., 0., 0.]),
 array([0.0049513 , 0.01160517, 0.00486292]),
 array([0., 0., 0.]),
 array([0., 0., 0.]),
 array([ 0.0196426 , -0.03379003, -0.0259291 ]),
 array([0.0152185 , 0.00922825, 0.01041559]),
 array([-0.01334607, -0.02800694, -0.00967384]),
 array([0., 0., 0.]),
 array([0., 0., 0.]),
 array([-0.0074668 , -0.01479827, -0.02794623]),
 array([-0.02779353, -0.02301871, -0.02184037]),
 array([0., 0., 0.]),
 array([-0.02962459, -0.00555259, -0.02378472]),
 array([0., 0., 0.]),
 array([0., 0., 0.]),
 array([0., 0., 0.]),
 array([0., 0., 0.]),
 array([-0.052659  , -0