# Notebook to build the Hebbian Learning module

In [1]:
import os
import sys

print(sys.executable)

# Directory added to the import search path (presumably the project checkout
# that provides `bsc_utils` -- confirm). Set the BSC_PROJECT_ROOT environment
# variable to use the notebook on another machine; the default is the original
# hard-coded location.
BSC_PROJECT_ROOT = os.environ.get(
    "BSC_PROJECT_ROOT",
    'C:\\Users\\Matthias\\OneDrive - UGent\\Documents\\DOCUMENTEN\\3. Thesis\\BSC\\',
)
# Guard the insert so that re-running this cell does not stack duplicate entries.
if BSC_PROJECT_ROOT not in sys.path:
    sys.path.insert(0, BSC_PROJECT_ROOT)

c:\Users\Matthias\OneDrive - UGent\Documents\DOCUMENTEN\3. Thesis\BSC\bsc\Scripts\python.exe


In [2]:
import chex
from typing import Sequence
import copy

import jax
from jax import numpy as jnp
from evosax import ParameterReshaper, OpenES

from bsc_utils.controller.base import ExplicitMLP

In [3]:
# Fixed seed: every key used further down is split off from this one.
rng = jax.random.PRNGKey(0)

# Print arrays with up to 15 digits of precision.
jnp.set_printoptions(precision = 15, suppress = False)

In [4]:
features = [4,5,2] # 2 hidden layers of 4 and 5 neurons, output dim is 2. Input dim is defined by the input vector used for instantiation
rng, rng_input = jax.random.split(rng, 2)
# NOTE(review): `input` shadows the Python builtin of the same name; later cells
# read this variable, so renaming it has to be done notebook-wide.
input = jax.random.normal(rng_input, (3,))
print(input)
joint_control = 'position'

[ 1.1378784  -1.2209548  -0.59153634]


In [5]:
model = ExplicitMLP(features = features, joint_control = joint_control) # ExplicitMLP is the project's own MLP, imported from bsc_utils.controller.base

In [6]:
print(model)
# Initialise the parameters with a fresh key; the input vector fixes the input dimension.
rng, rng_init = jax.random.split(rng, 2)
params_init = model.init(rng_init, input)
print(params_init)
print(jax.tree_util.tree_map(lambda x: x.shape, params_init)) # tree_map treats an array as a single leaf, whereas lists, tuples and dicts are nodes (nodes are themselves pytrees)
# the leaves of a tuple/list are the elements of that tuple/list

ExplicitMLP(
    # attributes
    features = [4, 5, 2]
    joint_control = 'position'
)
{'params': {'layers_0': {'kernel': Array([[ 0.91562444 , -0.043203816,  1.06029    ,  1.0857037  ],
       [-0.015480848,  1.1917892  , -0.6570196  ,  1.1079454  ],
       [ 0.2895664  , -0.24119082 ,  1.0076706  , -0.9029162  ]],      dtype=float32), 'bias': Array([0., 0., 0., 0.], dtype=float32)}, 'layers_1': {'kernel': Array([[ 0.42656583 , -0.0712561  ,  0.23940565 , -0.30804867 ,
        -0.780432   ],
       [-0.0604636  ,  0.54123    ,  0.78129745 ,  0.02796052 ,
        -0.013180583],
       [-0.3252826  ,  0.55973214 ,  0.16509312 ,  0.2372554  ,
        -0.43057743 ],
       [ 0.10360184 ,  0.7279873  ,  0.2179734  ,  0.07189767 ,
        -0.97823316 ]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0.], dtype=float32)}, 'layers_2': {'kernel': Array([[ 0.71163124, -0.5308686 ],
       [-0.10842434,  0.75031364],
       [ 0.3031431 , -0.297253  ],
       [ 1.0005071 , -0.2076626 ],
       

In [8]:
# model.apply now returns the action and the neuron activities separately

# NOTE(review): `x` and `neuron_activities` are both rebound by later cells.
x, neuron_activities = model.apply(params_init, input)
print(x)
print(neuron_activities)

[-0.29033384  0.28768677]
[Array([ 1.1378784 , -1.2209548 , -0.59153634], dtype=float32), Array([ 0.71113765, -0.87676555,  0.8880447 ,  0.39419347], dtype=float32), Array([ 0.10791126,  0.25320113, -0.27496824, -0.00454481, -0.86463517],      dtype=float32), Array([-0.29033384,  0.28768677], dtype=float32)]


In [7]:
# Scratch check: append one element to a tuple by concatenating a 1-tuple.
tup = (0, 1)
tup_ext = tup + (2,)
print(tup_ext)

(0, 1, 2)


In [8]:
def add_learning_rule_dim(x, n=5):
    """
    Build a zero-filled learning-rule template for a single pytree leaf.

    Meant to be mapped over pytrees containing 2D kernels and 1D bias arrays:
    - a 2D kernel of shape (i, j) yields zeros of shape (i, j, n), i.e. room
      for n learning-rule parameters per connection (jnp's default dtype);
    - any other leaf, e.g. a bias of shape (m,), yields zeros with the same
      shape and dtype as the leaf.

    The values stored in x are never read: THE RETURNED ARRAY CONTAINS ONLY ZEROS.
    """
    is_kernel = len(x.shape) == 2
    if not is_kernel:
        return jnp.zeros_like(x)
    return jnp.zeros((*x.shape, n))

In [9]:
# Small sanity check of add_learning_rule_dim on a 2D and a 1D array.
# NOTE(review): this rebinds `x`, which an earlier cell used for the model output.
x = jnp.array([[1,2,3],[4,5,6]])
y = jnp.array([7,8,9,10])

print(x)
print(y)

# 2D input gains a trailing axis of length 5; 1D input keeps its shape.
print(add_learning_rule_dim(x))
print(add_learning_rule_dim(y))


[[1 2 3]
 [4 5 6]]
[ 7  8  9 10]
[[[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]]
[0 0 0 0]


In [10]:
# Zero-filled pytree with the same structure as params_init, in which every
# kernel carries 5 learning-rule parameters per connection.
learning_rules_empty = jax.tree_util.tree_map(lambda x: add_learning_rule_dim(x,n=5), params_init)
print(jax.tree_util.tree_map(lambda x: x.shape, learning_rules_empty))
print(learning_rules_empty["params"]["layers_1"])
                                              

{'params': {'layers_0': {'bias': (4,), 'kernel': (3, 4, 5)}, 'layers_1': {'bias': (5,), 'kernel': (4, 5, 5)}, 'layers_2': {'bias': (2,), 'kernel': (5, 2, 5)}}}
{'bias': Array([0., 0., 0., 0., 0.], dtype=float32), 'kernel': Array([[[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]]], dtype=float32)}


In [11]:
# Reshaper for the learning-rule tree; it is used further down to turn flat
# candidate vectors back into this pytree structure.
param_reshaper = ParameterReshaper(learning_rules_empty)
num_params = param_reshaper.total_params
print(f"num_params in learning rule tree: {num_params}")
# For comparison: the same count for the plain synapse-strength tree.
param_reshaper_synapse_strength = ParameterReshaper(params_init)
num_params_synapse_strengths = param_reshaper_synapse_strength.total_params
print(f"num_params in synapse strength tree: {num_params_synapse_strengths}")
print(f"ratio of params increase: {num_params/num_params_synapse_strengths}")

ParameterReshaper: 221 parameters detected for optimization.
num_params in learning rule tree: 221
ParameterReshaper: 53 parameters detected for optimization.
num_params in synapse strength tree: 53
ratio of params increase: 4.169811320754717


In [12]:
# get learning rules from evosax
popsize = 2

rng, rng_ask, rng_init = jax.random.split(rng, 3)
strategy  = OpenES(popsize = popsize, num_dims = num_params)
es_params = strategy.default_params

es_state = strategy.initialize(rng_init, es_params)
learning_rules_flat, es_state = strategy.ask(rng_ask, es_state)
print(f"candidate solution shape: {learning_rules_flat.shape}")

learning_rules = param_reshaper.reshape(learning_rules_flat)
print(f"""
      reshaped policy params in jax:
      {jax.tree_util.tree_map(lambda x: x.shape, learning_rules)}
""")


candidate solution shape: (2, 221)

      reshaped policy params in jax:
      {'params': {'layers_0': {'bias': (2, 4), 'kernel': (2, 3, 4, 5)}, 'layers_1': {'bias': (2, 5), 'kernel': (2, 4, 5, 5)}, 'layers_2': {'bias': (2, 2), 'kernel': (2, 5, 2, 5)}}}



In [13]:
# Random initial synapse strengths, one flat vector per population member.
rng, rng_unif = jax.random.split(rng, 2)
synapse_strengths_init_flat = jax.random.uniform(rng_unif, shape=(popsize, num_params_synapse_strengths), minval = -0.1, maxval = 0.1) # based on Pedersen & Risi (2021)
print(f"synaptic strength initialisation: {synapse_strengths_init_flat.shape}")

# Reshape into the parameter pytree of the model, with a leading population axis on every leaf.
synapse_strengths_init = param_reshaper_synapse_strength.reshape(synapse_strengths_init_flat)
print(f"""
      reshaped policy params in jax:
      {jax.tree_util.tree_map(lambda x: x.shape, synapse_strengths_init)}
""")
# NOTE(review): this dumps the whole tree; the shape summary above is usually enough.
print(synapse_strengths_init)

synaptic strength initialisation: (2, 53)

      reshaped policy params in jax:
      {'params': {'layers_0': {'bias': (2, 4), 'kernel': (2, 3, 4)}, 'layers_1': {'bias': (2, 5), 'kernel': (2, 4, 5)}, 'layers_2': {'bias': (2, 2), 'kernel': (2, 5, 2)}}}

{'params': {'layers_0': {'bias': Array([[ 0.08441029  , -0.07492232  , -0.05319176  , -0.03989966  ],
       [ 0.0048117638,  0.051953316 , -0.09334455  , -0.09554732  ]],      dtype=float32), 'kernel': Array([[[ 0.06805022  ,  0.023456668 , -0.033478856 ,  0.08947175  ],
        [ 0.07558513  , -0.027463222 ,  0.04407022  , -0.03388474  ],
        [ 0.046473026 , -0.097225785 ,  0.049788833 , -0.071477346 ]],

       [[ 0.07617481  , -0.090289235 ,  0.08835657  ,  0.07394917  ],
        [-0.0667928   ,  0.07353234  , -0.07089801  ,  0.07646527  ],
        [-0.0061877253,  0.041908145 , -0.08476386  , -0.026584268 ]]],      dtype=float32)}, 'layers_1': {'bias': Array([[-0.046349265, -0.012251282, -0.09816499 ,  0.049443055,
         0.03

In [14]:
# One activity array per layer, each with a leading population axis.
# The same input vector is repeated for every population member.
input_stack = jnp.tile(input, (popsize, 1))
# All-zero placeholder activities for the layers listed in `features`. The leading
# axis uses popsize (it was a literal 2) so that it stays aligned with input_stack
# when the population size changes.
neuron_activities = [jnp.zeros((popsize, n)) for n in features]
neuron_activities = [input_stack] + neuron_activities
print(neuron_activities)
print(jax.tree_util.tree_map(lambda x: x.shape, neuron_activities))


[Array([[ 1.1378784 , -1.2209548 , -0.59153634],
       [ 1.1378784 , -1.2209548 , -0.59153634]], dtype=float32), Array([[0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32), Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32), Array([[0., 0.],
       [0., 0.]], dtype=float32)]
[(2, 3), (2, 4), (2, 5), (2, 2)]


In [15]:
def apply_learning_rule(lr_kernel, input_nodes, output_nodes):
    """
    Compute the synaptic strength increments of one layer for a single individual.

    Per-call shapes (the population axis is handled by the vmapped wrapper
    defined right after this function):
        lr_kernel:    (input_layer_dim, output_layer_dim, 5),
                      last axis holds [alpha, A, B, C, D]
        input_nodes:  (input_layer_dim,)  activities o_i
        output_nodes: (output_layer_dim,) activities o_j
    Returns:
        synaptic strength increment kernel of shape (input_layer_dim, output_layer_dim):
        Dw_ij = alpha_ij * (A_ij*o_i*o_j + B_ij * o_i + C_ij * o_j + D_ij)

    This function is vmapable and jittable.
    """
    assert lr_kernel.shape[-1] == 5, "Learning rule requires 5 parameters or needs to be updated so it is compatible with different number of parameters"

    # Broadcasting stands in for explicit tiling:
    # `pre` varies along axis 0 (input dimension), `post` along axis 1 (output dimension).
    pre = input_nodes[:, None]
    post = output_nodes[None, :]

    # Split the last axis into the five coefficient planes.
    alpha, A, B, C, D = (lr_kernel[:, :, k] for k in range(5))

    return alpha * (A * pre * post + B * pre + C * post + D)


# Mapped over the leading axis of all three arguments, then jit-compiled.
apply_learning_rule_vect = jax.jit(jax.vmap(apply_learning_rule))

In [16]:
# Independent snapshot of the initial strengths, kept for the before/after comparison below.
synapse_strengths_init_save = copy.deepcopy(synapse_strengths_init)


In [27]:
def update_synapse_strengths(
        synapse_strengths_input: dict,
        learning_rules: dict, # pytree
        neuron_activities: Sequence[chex.Array] # jax array
        ) -> dict:
    """
    Return a new synapse-strength tree after one application of the learning rules.

    For every layer p listed in learning_rules["params"]:
    - kernel: the increment computed by apply_learning_rule_vect from the layer's
      learning-rule kernel and the activities of layers p and p+1 is added to
      the current kernel;
    - bias: replaced by the bias stored in the learning-rule tree.

    Args:
        synapse_strengths_input: tree with a "params" entry holding
            "layers_{p}" dicts with "kernel" and "bias"; it is NOT modified.
        learning_rules: tree with the same "params"/"layers_{p}" layout.
        neuron_activities: one array per layer, indexed so that entries p and
            p+1 are the input and output activities of layer p.

    Returns:
        A new dict with the updated "params"; entries that are not touched are
        shared with the input rather than copied.
    """
    # Item assignment on the nested dicts would change the caller's tree, because
    # dicts are passed by reference. Rebuilding only the dicts that change avoids
    # that without deep-copying every array on each call.
    updated_layers = dict(synapse_strengths_input["params"])

    num_layers = len(learning_rules["params"])
    for p in range(num_layers):
        layer_key = f"layers_{p}"
        layer_rules = learning_rules["params"][layer_key]
        current_layer = updated_layers[layer_key]

        ss_incr_kernel = apply_learning_rule_vect(
            layer_rules["kernel"], neuron_activities[p], neuron_activities[p+1]
        )

        updated_layers[layer_key] = {
            **current_layer,
            "kernel": current_layer["kernel"] + ss_incr_kernel,
            "bias": layer_rules["bias"],
        }

    return {**synapse_strengths_input, "params": updated_layers}


# Apply one update and check that the tree structure and shapes are unchanged.
synapse_strengths = update_synapse_strengths(synapse_strengths_init, learning_rules, neuron_activities)
print(jax.tree_util.tree_map(lambda x: x.shape, synapse_strengths))
# Leaf-wise difference (saved initial values minus updated values).
print(jax.tree_util.tree_map(lambda x,y: x-y, synapse_strengths_init_save, synapse_strengths))

{'params': {'layers_0': {'bias': (2, 4), 'kernel': (2, 3, 4)}, 'layers_1': {'bias': (2, 5), 'kernel': (2, 4, 5)}, 'layers_2': {'bias': (2, 2), 'kernel': (2, 5, 2)}}}
{'params': {'layers_0': {'bias': Array([[ 0.095749214 , -0.08305403  , -0.051500898 , -0.027483817 ],
       [-0.0065271594,  0.06008502  , -0.09503541  , -0.10796316  ]],      dtype=float32), 'kernel': Array([[[ 0.0003116727 ,  0.001973167  ,  0.0034631677 , -0.023116484  ],
        [ 0.00055105984,  0.0043627936 ,  0.016441975  ,  0.033707418  ],
        [-0.0042980723 , -0.015313536  ,  0.0009611361 , -0.008618005  ]],

       [[ 0.0003116727 ,  0.001973167  ,  0.0034631342 , -0.023116484  ],
        [ 0.00055105984,  0.004362814  ,  0.016441941  ,  0.03370741   ],
        [-0.004298064  , -0.015313536  ,  0.0009611696 , -0.008617988  ]]],      dtype=float32)}, 'layers_1': {'bias': Array([[-0.07035379  , -0.039831672 , -0.07514771  , -0.007336233 ,
         0.031622782 ],
       [-0.029707434 , -0.0051215794, -0.1209650