In [73]:
import jax
from jax import numpy as jnp
from typing import Union
import pickle
import copy
from bsc_utils.miscellaneous import multimodal_normal_sampling
from bsc_utils.controller.base import ExplicitMLP
from bsc_utils.simulate.analyze import Simulator
from bsc_utils.miscellaneous import load_config_from_yaml
from bsc_utils.miscellaneous import complete_config_with_defaults, presynaptic_competition_rescale
import time
import numpy as np

from mediapy import show_video

from bsc_utils.visualization import post_render

rng = jax.random.PRNGKey(0)

In [74]:
### Configurables
simulation_time = 5

# consolidation network
num_hidden = 4 # number of hidden layers
nodes_hidden = 128 # number of nodes per hidden layer

# Kernel initialisation
means = [-3,3]
stds = [1,1]
trunc_mins = [-6,0]
trunc_maxs = [0,6]

# means = [0]
# stds = [0.5]
# trunc_mins = [-1]
# trunc_maxs = [1]


alpha = 0.1

In [75]:
with open("observations.pkl", "rb") as file:
    observations_all = pickle.load(file)
observations = jnp.concatenate(
    [observations_all[label] for label in ('joint_position', 'joint_actuator_force', 'segment_contact')],
    axis = 1
)
observations = observations[0].T

with open("nn_params_example.pkl", "rb") as file:
    nn_params_example = pickle.load(file)

actions = jnp.load("actions.npy")[0].T
rewards = jnp.load("rewards.npy")[0]

print(actions.shape)
print(rewards.shape)
print(observations.shape)
print(jax.tree.map(lambda x: x.shape, nn_params_example))

(250, 50)
(250,)
(250, 125)
{'params': {'layers_0': {'bias': (128,), 'kernel': (125, 128)}, 'layers_1': {'bias': (128,), 'kernel': (128, 128)}, 'layers_2': {'bias': (50,), 'kernel': (128, 50)}}}


In [76]:
# make explictMLP with way more layers and biases
# generate fictional rollout, at the end of every rollout, you manually set the output as desired

In [77]:
hidden_layers = []

for i in range(num_hidden):
    hidden_layers += [nodes_hidden]

input_dim = nn_params_example["params"]["layers_0"]["kernel"].shape[0]
output_dim = nn_params_example["params"]["layers_2"]["kernel"].shape[1]

layers = [input_dim] + hidden_layers + [output_dim]
print(layers)

[125, 128, 128, 128, 128, 50]


In [78]:
cons_contr = ExplicitMLP(features = layers[1:], joint_control = "position")

rng, rng_init = jax.random.split(rng, 2)
cons_contr_params = cons_contr.init(rng_init, jnp.zeros(layers[0]))

print(jax.tree.map(lambda x: x.size, cons_contr_params))

{'params': {'layers_0': {'bias': 128, 'kernel': 16000}, 'layers_1': {'bias': 128, 'kernel': 16384}, 'layers_2': {'bias': 128, 'kernel': 16384}, 'layers_3': {'bias': 128, 'kernel': 16384}, 'layers_4': {'bias': 50, 'kernel': 6400}}}


In [79]:
def identical(x):
    return x

In [80]:
def DNN_rollout(params_dict, input, act_hidden=jnp.tanh, act_output=jnp.tanh, joint_control="position"):
    num_layers = len(params_dict["params"].keys())
    neuron_activities = [input]
    for p in range(num_layers):
        kernel = params_dict["params"][f"layers_{p}"]["kernel"]
        bias = params_dict["params"][f"layers_{p}"]["bias"]
        output = jnp.dot(input,kernel)
        output = output+bias

        ### !!! activity which is stored is not yet rescaled by the tanh function
        neuron_activities.append(output)
        if p != num_layers-1:
            output = act_hidden(output)
        elif p == num_layers-1:
            assert joint_control in ['position', 'torque'], "joint_control should be either 'position' or 'torque'"
            if joint_control == 'position':
                output = 30*jnp.pi/180 * act_output(output) # the action space range for positions is -0.5236..0.5236
            elif joint_control == 'torque':
                output = act_output(output) # the action space range for torques is -1..1
        
        input = output
    return output, neuron_activities

In [81]:
def initialise_cons_contr_params(path, leaf):
    rng = jax.random.PRNGKey(2000)
    rng, rng_init = jax.random.split(rng, 2)
    if "kernel" in jax.tree_util.keystr(path):
        leaf_old = leaf
        sample_sizes = [int(leaf_old.size/len(means)) for _ in range(len(means))]
        values = multimodal_normal_sampling(rng,
                                            means,
                                            stds,
                                            sample_sizes,
                                            trunc_mins,
                                            trunc_maxs)
        leaf = jnp.reshape(values, leaf_old.shape)
        return leaf  # Modify leaf
    elif "bias" in jax.tree_util.keystr(path):
        leaf = leaf*0
        return leaf  # Modify leaf
    return leaf  # Keep unchanged otherwise

cons_contr_params = jax.tree_util.tree_map_with_path(initialise_cons_contr_params, cons_contr_params)
# print(cons_contr_params)


In [82]:
# def synaptic_competition(input_nodes,
#                          output_nodes,
#                          synapse_strengths,
#                          learning_rate = 0.1
#                          ):
#     synapse_strengths = np.array(synapse_strengths)
#     for i in range(len(output_nodes)):
#         for j in range(len(input_nodes)):
#             synapse_strengths[j,i] = synapse_strengths[j,i] + learning_rate * input_nodes[j] * output_nodes[i]

#     return jnp.array(synapse_strengths)

def apply_oja(input_nodes,
            output_nodes,
            synapse_strengths,
            alpha = 0.1
            ):
    
    """ 
    Input: 
    - learning rule kernel: dims (popsize, input_layer_dim, output_layer_dim, lr_dim = 5)
    - learning_rule: only "ABCD" implemented so far.
    Output: synaptic strength increment kernel: dims (popsize, input_layer_dim, output_layer_dim)
    This function is vmapable and jittable
    """
    in_dim = len(input_nodes)
    out_dim = len(output_nodes)
    inp = jnp.transpose(jnp.tile(input_nodes, (out_dim, 1))) # Generates (in_dim, out_dim) dimension, but constant along axis = 1 (output dimension)
    outp = jnp.tile(output_nodes, (in_dim, 1)) # Generates (in_dim, out_dim) dimension, but constant along axis = 0 (input dimesnion)
    # these dimensions are required for matrix multiplications

    # kernel content: [alpha, A, B, C, D] --> Dw_ij = alpha_ij * (A_ij*o_i*o_j + B_ij * o_i + C_ij * o_j + D_ij)

    ss_incr = alpha* inp * outp
    return ss_incr

    ss_incr = alpha * (inp*outp - synapse_strengths * outp**2) # 10*synapse_strengths*outp)
    return ss_incr

    ss_factor = jnp.exp(ss_incr)
    return ss_factor # synaptic strength multiplication factor


def update_synapse_strengths(
        synapse_strengths_input,
        neuron_activities,
        alpha = 0.1
):
    synapse_strengths = copy.deepcopy(synapse_strengths_input)

    num_layers = len(synapse_strengths["params"].keys())
    for p in range(num_layers):
        input_nodes = neuron_activities[p]
        output_nodes = neuron_activities[p+1]
        ss = synapse_strengths["params"][f"layers_{p}"]["kernel"]
        ss_factor = apply_oja(input_nodes, output_nodes, ss, alpha = alpha)
        # synapse_strengths["params"][f"layers_{p}"]["kernel"] = synapse_strengths["params"][f"layers_{p}"]["kernel"] * ss_factor
        synapse_strengths["params"][f"layers_{p}"]["kernel"] = synapse_strengths["params"][f"layers_{p}"]["kernel"] + ss_factor
        # synapse_strengths["params"][f"layers_{p}"]["kernel"] = synaptic_competition(input_nodes, output_nodes, synapse_strengths=ss)
    
    synapse_strengths = presynaptic_competition_rescale(synapse_strengths)

    return synapse_strengths


In [83]:

from flax import linen as nn

for i in range(observations.shape[0]):
    # print("observations: ", observations[i,:5])
    # _, neuron_activities = DNN_rollout(cons_contr_params, observations[i,:], act_hidden = nn.tanh, act_output=nn.tanh)
    _, neuron_activities = cons_contr.apply(cons_contr_params, observations[i,:], act_hidden = nn.tanh, act_output=nn.tanh)
    print("neuron_activities: ", neuron_activities[1])
    neuron_activities[-1] = actions[i,:]

    cons_contr_params = update_synapse_strengths(cons_contr_params, neuron_activities, alpha=alpha)
    print("weights: ", cons_contr_params["params"]["layers_0"]["kernel"][0,:7])
    # print(cons_contr_params)

neuron_activities:  [-1.         1.        -1.        -1.        -1.        -1.
  1.        -1.        -1.         1.         1.        -1.
  1.        -1.        -1.         1.        -1.         1.
 -1.        -1.         1.        -1.        -1.         1.
  1.        -1.         1.        -1.        -1.        -1.
 -1.         1.         1.        -1.        -1.         1.
 -1.        -1.        -1.         0.9999993 -1.         1.
 -1.         1.        -1.        -1.         1.        -1.
 -1.         1.         1.         1.         1.         1.
  1.         1.         1.         1.        -1.         1.
  1.         1.        -1.         1.         1.        -1.
  1.         1.         1.        -1.        -1.        -1.
 -1.         1.         1.        -1.        -1.        -1.
 -1.        -1.         1.        -1.        -1.         1.
  1.         1.         1.         1.         1.         1.
 -1.        -1.         1.        -1.         1.        -1.
 -1.         1.     

In [84]:
from bsc_utils.BrittleStarEnv import EnvContainer


config = load_config_from_yaml("2024_04_01_b01_r03.yaml")
config = complete_config_with_defaults(config)
config["environment"]["simulation_time"] = simulation_time

env_container = EnvContainer(config)
env_container.generate_env()
env_container.generate_env_damaged()
env_container.visualize_arena()
env_container.visualize_morphology()

In [None]:
rng, rng_reset = jax.random.split(rng, 2)
env_state = env_container.env.reset(rng_reset)
mjx_frames = []

i = 0
while not jnp.any(env_state.terminated | env_state.truncated):
    if i%10 == 0:
        print(i)
    i += 1
    
    sensory_input = jnp.concatenate(
        [env_state.observations[label] for label in config["environment"]["sensor_selection"]],
        # axis = 1
    )

    start_action = time.time()
    action,_ = cons_contr.apply(cons_contr_params, sensory_input)
    print(f"action rollout time = {time.time()-start_action}")

    start_env_step = time.time()
    env_state = env_container.env.step(state=env_state, action=action)
    print(f"env step time = {time.time()-start_env_step}")
    
    mjx_frames.append(
            post_render(
                env_container.env.render(state=env_state),
                env_container.env.environment_configuration
                )
            )
show_video(images=mjx_frames)

0
action rollout time = 0.008483409881591797


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 13.966931819915771
action rollout time = 0.009111642837524414


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 14.188769102096558
action rollout time = 0.007889986038208008


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 12.93776273727417
action rollout time = 0.008735895156860352


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 13.495914936065674
action rollout time = 0.013150453567504883


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 13.49060344696045
action rollout time = 0.008590221405029297


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 15.398575067520142
action rollout time = 0.00892949104309082


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 13.267178297042847
action rollout time = 0.009534120559692383


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 13.65910267829895
action rollout time = 0.01153421401977539


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 13.656423807144165
action rollout time = 0.00850534439086914


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 15.485229969024658
10
action rollout time = 0.008791446685791016


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 14.101575136184692
action rollout time = 0.008425712585449219


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 13.513196468353271
action rollout time = 0.00931096076965332


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 14.295178174972534
action rollout time = 0.009356975555419922


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 14.008476257324219
action rollout time = 0.009390592575073242


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 18.785033464431763
action rollout time = 0.008832931518554688


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 13.838455200195312
action rollout time = 0.009278297424316406


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 13.58275032043457
action rollout time = 0.029319286346435547


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 15.099056005477905
action rollout time = 0.010829448699951172


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 14.74323296546936
action rollout time = 0.012490034103393555


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 15.886333703994751
20
action rollout time = 0.008583784103393555


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 14.07695984840393
action rollout time = 0.01909780502319336


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 13.24793815612793
action rollout time = 0.009657621383666992


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 14.101662874221802
action rollout time = 0.012634038925170898


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


env step time = 15.271990299224854
action rollout time = 0.009781837463378906
