# Developping a framework for the centralised controller of a Brittle Star robot
Integrating the brittle star morphology and environment framework of the bio-inspired robotics benchmark with the Evosac implementations of neural networks.

In [8]:
import sys

print(sys.executable)

/apps/gent/RHEL8/cascadelake-ib/software/Python/3.11.3-GCCcore-12.3.0/bin/python


In [44]:
import evosax
import jax
from jax import random, numpy as jnp
from evosax import OpenES, ParameterReshaper, NetworkMapper
import flax
from flax import linen as nn
from typing import Any, Callable, Sequence
from evosax import ParameterReshaper
rng = jax.random.PRNGKey(0) # make an rng right away and every split throughout the document should make a new rng
# this new rng should only be used for the sole purpose of splitting in the future

In [23]:
# build NN architecture
class ExplicitMLP(nn.Module):
    features: Sequence[int]
    # len of features = number of layers
    # integers in the sequence determine the number of nodes

    def setup(
        self
    ):
        """
        Fully connected neural network, characterised by a pytree (dict containing dict with params and biases)
        Features represents the number of outputs of the Dense layer
        inputs based on the presented input later on
        after presenting input: kernel can be generated
        """
        self.layers = [nn.Dense(feat) for feat in self.features]



    def __call__(
        self,
        inputs,
        act_hidden: Callable = nn.tanh,
        act_output: Callable = nn.tanh
    ):
        """
        Returning the output of a layer for a given input.
        Don't directly call an instance of ExplicitMLP --> this method is called in the apply method.
        -----
        act_hidden: activation function applied to hidden layers: popular is nn.tanh or nn.relu
        act_output: activation function applied to output layer: popular is nn.tanh or nn.sigmoid
        """
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = act_hidden(x)
            else:
                x = act_output(x)
            return x

Let us base the development of the first controller on a brittle star with 2 arms and 5 segments, 10 segments in total, 20 degrees of freedom, meaning 20 joint angle positions to be controlled. This means the final layer of our output is 20.
The inputs will be sampled from the Brittle Star Observation space, but let's for now just initialise an input space with 10 random inputs.

TO DO: FETCH INPUT ARRAY AND INPUT DIMENSIONS FROM BRITTLE STAR ENVIRONMENT

In [37]:
arm_setup = [5,0,5,0,0] # 2 arms with 5 segments
dofs = 2*sum(arm_setup)
print(dofs)
model = ExplicitMLP(features = [dofs])

# initialising the parameters of the model
rng, rng_input, rng_init = random.split(rng, 3)
x = random.normal(rng_input, (10,))
print(x)
params = model.init(rng_init, x)
# params is a PyTree --> see jax documentation
# print(params)
print(jax.tree_util.tree_map(lambda x: x.shape, params))
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))

20
[ 0.06338525  0.32473582  0.02520137  0.17405976  1.6556118  -0.87135386
  0.6929723  -0.6462938  -1.7196621   0.01410782]
{'params': {'layers_0': {'bias': (20,), 'kernel': (10, 20)}}}
initialized parameter shapes:
 {'params': {'layers_0': {'bias': (20,), 'kernel': (10, 20)}}}


A single forward pass through the model

In [35]:
model.apply(params,x)

Array([-0.6768114 ,  0.5990112 ,  0.5021584 ,  0.07781484, -0.82217586,
       -0.68930304,  0.77736187, -0.00886519, -0.777349  , -0.97340053,
       -0.7469593 , -0.26257956,  0.43478754, -0.2912326 , -0.60248494,
       -0.64250827,  0.8634078 , -0.07302105,  0.45532835,  0.8936004 ],      dtype=float32)

Indeed, all outputs are distributed between -1 and +1, as expected from the tanh activation fucntion of the output layer

In [52]:
param_reshaper = ParameterReshaper(params)
num_params = param_reshaper.total_params # get from the weights and biases of the NN

ParameterReshaper: 220 parameters detected for optimization.


In [59]:
# instantiate the search strategy

rng, rng_init = jax.random.split(rng, 2)
strategy  = OpenES(popsize = 100, num_dims = num_params)
# still parameters that can be finetuned, like optimisation method, lrate, lrate decay, ...
es_params = strategy.default_params
# # replacing certain parameters:
# es_params = es_params.replace(init_min = -3, init_max = 3)
print(es_params)

state = strategy.initialize(rng_init, es_params)

candidate, state = strategy.ask(rng, state)
print('candidate solution shape', candidate.shape)
print('candidate solutions', candidate)

net_params = param_reshaper.reshape(candidate)
print('net_params: ', net_params)


EvoParams(opt_params=OptParams(lrate_init=0.05, lrate_decay=1.0, lrate_limit=0.001, momentum=None, beta_1=0.99, beta_2=0.999, beta_3=None, eps=1e-08, max_speed=None), sigma_init=0.03, sigma_decay=1.0, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38)
candidate solution shape (100, 220)
candidate solutions [[ 0.00152383 -0.00245873 -0.01534315 ... -0.00033993 -0.00899743
   0.04514204]
 [ 0.02709437  0.01101336  0.00762423 ... -0.02867637 -0.0827048
   0.00018243]
 [ 0.03824715 -0.0268039   0.00814435 ... -0.06754368  0.01677538
  -0.0185631 ]
 ...
 [-0.03424363  0.02744813  0.04431093 ...  0.00107031 -0.02235126
  -0.02651232]
 [ 0.03676725  0.01331418  0.01328681 ...  0.01505064  0.01554712
   0.01992323]
 [ 0.01600746  0.01080545  0.00724279 ... -0.00690854  0.04143371
  -0.00817915]]
net_params:  {'params': {'layers_0': {'bias': Array([[ 0.00152383, -0.00245873, -0.01534315, ..., -0.03308293,
         0.00207618,  0.0240575 ],
       [ 0.

In [None]:
num_generations = 2500
# print_every_k_gens = 100 --> replace by wandb monitoring
# Run ask-eval-tell loop - NOTE: By default minimization!
for gen in range(num_generations):
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, es_params)
    fitness = ...  # Your population evaluation fct 
    state = strategy.tell(x, fitness, state, es_params)

# Get best overall population member & its fitness
state.best_member, state.best_fitness

In [3]:
try:
    import brb
except ImportError:
    !{sys.executable} -m pip install git+https://github.com/Co-Evolve/brb@new-framework

try:
    import wandb
except:
    !{sys.executable} -m pip install wandb

## Checking accesibility GPU

In [4]:
import os
import subprocess
import logging

try:
    if subprocess.run('nvidia-smi').returncode:
        raise RuntimeError(
                'Cannot communicate with GPU. '
                'Make sure you are using a GPU Colab runtime. '
                'Go to the Runtime menu and select Choose runtime type.'
                )

    # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
    # This is usually installed as part of an Nvidia driver package, but the Colab
    # kernel doesn't install its driver via APT, and as a result the ICD is missing.
    # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
    NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
    if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
        with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
            f.write(
                    """{
                            "file_format_version" : "1.0.0",
                            "ICD" : {
                                "library_path" : "libEGL_nvidia.so.0"
                            }
                        }
                        """
                    )

    # Configure MuJoCo to use the EGL rendering backend (requires GPU)
    print('Setting environment variable to use GPU rendering:')
    %env MUJOCO_GL=egl

    # Check if jax finds the GPU
    import jax

    print(jax.devices('gpu'))
except Exception:
    logging.warning("Failed to initialize GPU. Everything will run on the cpu.")

try:
    print('Checking that the mujoco installation succeeded:')
    import mujoco

    mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
    raise e from RuntimeError(
            'Something went wrong during installation. Check the shell output above '
            'for more information.\n'
            'If using a hosted Colab runtime, make sure you enable GPU acceleration '
            'by going to the Runtime menu and selecting "Choose runtime type".'
            )

print('MuJoCo installation successful.')

Sat Jan 27 19:07:33 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A2                      On  | 00000000:3B:00.0 Off |                    0 |
|  0%   41C    P8               8W /  60W |      4MiB / 15356MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    



Checking that the mujoco installation succeeded:
MuJoCo installation successful.


## Weights and Biases trial

In [5]:
import wandb
import random

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project = "my-awesome-project",

    # track hyperparameters and run metadata
    config={
    "learning_rate": 0.02,
    "architecture": "CNN",
    "dataset": "CIFAR-100",
    "epochs": 10,
    }
)

# simulate training
epochs = 10
offset = random.random()/5
for epoch in range(2, epochs):
    acc = 1 - 2 ** -epoch - random.random() / epoch - offset
    loss = 2** -epoch + random.random() / epoch + offset

    # log memtrics to wandb
    wandb.log({"acc": acc, "loss":loss})

# [optional] finish the wandb run, necessary in notebooks
wandb.finish()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmatthias-pex[0m. Use [1m`wandb login --relogin`[0m to force relogin




VBox(children=(Label(value='0.005 MB of 0.013 MB uploaded\r'), FloatProgress(value=0.3774859287054409, max=1.0…

0,1
acc,▁▄▇█▆▆▇█
loss,▇█▄▃▄▂▁▁

0,1
acc,0.79047
loss,0.18106
