# Developping a framework for the centralised controller of a Brittle Star robot
Integrating the brittle star morphology and environment framework of the bio-inspired robotics benchmark with the Evosac implementations of neural networks.

In [2]:
import sys

print(sys.executable)

/user/gent/457/vsc45787/.conda/envs/hope/bin/python


In [3]:
import numpy as np
import jax
from jax import numpy as jnp
import evosax
from evosax import OpenES, ParameterReshaper, NetworkMapper, ParameterReshaper
import flax
from flax import linen as nn
from typing import Any, Callable, Sequence, Union, List
import brb

import wandb
print(wandb.__path__)
%env "WANDB_NOTEBOOK_NAME" "Centralized controller framework"

rng = jax.random.PRNGKey(0) # make an rng right away and every split throughout the document should make a new rng
# this new rng should only be used for the sole purpose of splitting in the future


['/user/gent/457/vsc45787/.conda/envs/hope/lib/python3.11/site-packages/wandb']
env: "WANDB_NOTEBOOK_NAME"="Centralized controller framework"


## Checking accesibility GPU

In [4]:
import os
import subprocess
import logging

try:
    if subprocess.run('nvidia-smi').returncode:
        raise RuntimeError(
                'Cannot communicate with GPU. '
                'Make sure you are using a GPU Colab runtime. '
                'Go to the Runtime menu and select Choose runtime type.'
                )

    # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
    # This is usually installed as part of an Nvidia driver package, but the Colab
    # kernel doesn't install its driver via APT, and as a result the ICD is missing.
    # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
    NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
    if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
        with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
            f.write(
                    """{
                            "file_format_version" : "1.0.0",
                            "ICD" : {
                                "library_path" : "libEGL_nvidia.so.0"
                            }
                        }
                        """
                    )

    # Configure MuJoCo to use the EGL rendering backend (requires GPU)
    print('Setting environment variable to use GPU rendering:')
    %env MUJOCO_GL=egl

    # Check if jax finds the GPU
    import jax

    print(jax.devices('gpu'))
except Exception:
    logging.warning("Failed to initialize GPU. Everything will run on the cpu.")

try:
    print('Checking that the mujoco installation succeeded:')
    import mujoco

    mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
    raise e from RuntimeError(
            'Something went wrong during installation. Check the shell output above '
            'for more information.\n'
            'If using a hosted Colab runtime, make sure you enable GPU acceleration '
            'by going to the Runtime menu and selecting "Choose runtime type".'
            )

print('MuJoCo installation successful.')

Wed Jan 31 18:01:27 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A2                      On  | 00000000:3B:00.0 Off |                    0 |
|  0%   44C    P0              19W /  60W |  11250MiB / 15356MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

## Weights and Biases trial

In [5]:
# # start a new wandb run to track this script
# rng, rng_acc, rng_loss, rng_offset = jax.random.split(rng, 4)
# wandb.init(
#     # set the wandb project where this run will be logged
#     project = "my-awesome-project",

#     # track hyperparameters and run metadata
#     config={
#     "learning_rate": 0.02,
#     "architecture": "CNN",
#     "dataset": "CIFAR-100",
#     "epochs": 10,
#     }
# )

# # simulate training
# epochs = 10
# offset = jax.random.uniform(rng_offset)/5
# for epoch in range(2, epochs):
#     acc = 1 - 2 ** -epoch - jax.random.uniform(rng_acc) / epoch - offset
#     loss = 2** -epoch + jax.random.uniform(rng_loss) / epoch + offset

#     # log memtrics to wandb
#     wandb.log({"acc": acc, "loss":loss})

# # [optional] finish the wandb run, necessary in notebooks
# wandb.finish()

## Defining classes and functionalities

### Generating neural network

In [6]:
# build NN architecture
class ExplicitMLP(nn.Module):
    features: Sequence[int]
    # act_hidden: Callable = nn.tanh,
    # act_output: Callable = nn.tanh

    """
    features: number of outputs (# nodes) for each layer. The number of inputs of the first layer defined by the call function.
    The number of inputs of the hidden layers and output layer defined by the number of outputs of the previous layer.
    act_hidden: activation function applied to hidden layers: popular is nn.tanh or nn.relu
    act_output: activation function applied to output layer: popular is nn.tanh or nn.sigmoid
    """

    def setup(
        self
    ):
        """
        Fully connected neural network, characterised by a pytree (dict containing dict with params and biases)
        Features represents the number of outputs of the Dense layer
        inputs based on the presented input later on
        after presenting input: kernel can be generated
        """
        self.layers = [nn.Dense(feat) for feat in self.features]



    def __call__(
        self,
        inputs,
        act_hidden: Callable = nn.tanh,
        act_output: Callable = nn.tanh
    ):
        """
        Returning the output of a layer for a given input.
        Don't directly call an instance of ExplicitMLP --> this method is called in the apply method.
        -----
        
        """
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = act_hidden(x)
            else:
                x = act_output(x)
        return x

### Utilities for visualisation

In [7]:
from mujoco_utils.environment.base import MuJoCoEnvironmentConfiguration
from mujoco_utils.mjcf_utils import MJCFRootComponent

# Graphics and plotting.
ffmpeg_v = !command -v ffmpeg
assert "command not found" not in ffmpeg_v, f"Please install FFmpeg for visualizations."
!{sys.executable} -m pip install -q mediapy
import mediapy as media

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)
jnp.set_printoptions(precision=3, suppress=True, linewidth=100)


def visualize_mjcf(
        mjcf: MJCFRootComponent
        ) -> None:
    model = mujoco.MjModel.from_xml_string(mjcf.get_mjcf_str())
    data = mujoco.MjData(model)
    renderer = mujoco.Renderer(model)
    mujoco.mj_forward(model, data)
    renderer.update_scene(data)
    media.show_image(renderer.render())


def post_render(
        render_output: List[np.ndarray],
        environment_configuration: MuJoCoEnvironmentConfiguration
        ) -> np.ndarray:
    if render_output is None:
        # Temporary workaround until https://github.com/google-deepmind/mujoco/issues/1379 is fixed
        return None

    num_cameras = len(environment_configuration.camera_ids)
    num_envs = len(render_output) // num_cameras

    if num_cameras > 1:
        # Horizontally stack frames of the same environment
        frames_per_env = np.array_split(render_output, num_envs)
        render_output = [np.concatenate(env_frames, axis=1) for env_frames in frames_per_env]

    # Vertically stack frames of different environments
    render_output = np.concatenate(render_output, axis=0)

    return render_output[:, :, ::-1]  # RGB to BGR


def show_video(
        images: List[np.ndarray | None]
        ) -> str | None:
    # Temporary workaround until https://github.com/google-deepmind/mujoco/issues/1379 is fixed
    filtered_images = [image for image in images if image is not None]
    num_nones = len(images) - len(filtered_images)
    if num_nones > 0:
        logging.warning(f"env.render produced {num_nones} None's. Resulting video might be a bit choppy (consquence of https://github.com/google-deepmind/mujoco/issues/1379).")
    return media.show_video(images=filtered_images)

### The Brittle Star Environment

In [8]:
# Creating morphology
import gymnasium
from brb.brittle_star.mjcf.morphology.specification.default import default_brittle_star_morphology_specification
from brb.brittle_star.mjcf.morphology.morphology import MJCFBrittleStarMorphology
from brb.brittle_star.mjcf.morphology.specification.specification import BrittleStarMorphologySpecification

def create_morphology(
        morphology_specification: BrittleStarMorphologySpecification
        ) -> MJCFBrittleStarMorphology:
    morphology = MJCFBrittleStarMorphology(
            specification=morphology_specification
            )
    return morphology


In [9]:
# Creating arena
from brb.brittle_star.mjcf.arena.aquarium import AquariumArenaConfiguration, MJCFAquariumArena

def create_arena(
        arena_configuration: AquariumArenaConfiguration
        ) -> MJCFAquariumArena:
    arena = MJCFAquariumArena(
            configuration=arena_configuration
            )
    return arena

In [10]:
# Creating environment
from brb.brittle_star.environment.light_escape.shared import BrittleStarLightEscapeEnvironmentConfiguration
from brb.brittle_star.environment.directed_locomotion.shared import \
    BrittleStarDirectedLocomotionEnvironmentConfiguration
from brb.brittle_star.environment.undirected_locomotion.shared import \
    BrittleStarUndirectedLocomotionEnvironmentConfiguration


from brb.brittle_star.environment.undirected_locomotion.dual import BrittleStarUndirectedLocomotionEnvironment
from brb.brittle_star.environment.directed_locomotion.dual import BrittleStarDirectedLocomotionEnvironment
from brb.brittle_star.environment.light_escape.dual import BrittleStarLightEscapeEnvironment
from mujoco_utils.environment.dual import DualMuJoCoEnvironment
from mujoco_utils.environment.base import MuJoCoEnvironmentConfiguration


def create_environment(
        morphology_specification: BrittleStarMorphologySpecification,
        arena_configuration: AquariumArenaConfiguration,
        environment_configuration: MuJoCoEnvironmentConfiguration,
        backend: str
        ) -> DualMuJoCoEnvironment:
    assert backend in ["MJC", "MJX"], "Please specify a valid backend; Either 'MJC' or 'MJX'"

    morphology = create_morphology(
            morphology_specification=morphology_specification
            )
    arena = create_arena(
            arena_configuration=arena_configuration
            )
    if isinstance(environment_configuration, BrittleStarUndirectedLocomotionEnvironmentConfiguration):
        env_class = BrittleStarUndirectedLocomotionEnvironment
    elif isinstance(environment_configuration, BrittleStarDirectedLocomotionEnvironmentConfiguration):
        env_class = BrittleStarDirectedLocomotionEnvironment
    else:
        env_class = BrittleStarLightEscapeEnvironment

    env = env_class.from_morphology_and_arena(
            morphology=morphology, arena=arena, configuration=environment_configuration, backend=backend
            )
    return env

## Generating instances

### Instantiating the environment

In [11]:
arm_setup = [5,0,5,0,0] # 2 arms with 5 segments
dofs = 2*sum(arm_setup)
print(dofs)
print(len(arm_setup))


reward_type = "distance" # choose "distance", "target", "light"
num_physics_steps_per_control_step=10 
simulation_time=5 # [seconds]
joint_randomization_noise_scale=0.0
# If this value is > 0 then we will add randomly sampled noise to the initial joint positions and velocities


# specifying morphology
morphology_specification = default_brittle_star_morphology_specification(
        num_arms=len(arm_setup), num_segments_per_arm=arm_setup, use_p_control=True, use_torque_control=False
        )
morphology = create_morphology(morphology_specification=morphology_specification)
visualize_mjcf(mjcf=morphology)


# specifying arena
arena_configuration = AquariumArenaConfiguration(
        size=(10, 5), sand_ground_color=True, attach_target=False, wall_height=1.5, wall_thickness=0.1
        )
arena = create_arena(arena_configuration=arena_configuration)
visualize_mjcf(mjcf=arena)


# specifying environment: CHOOSE 1
assert reward_type in (["distance","target","light"]), "reward_type must be one of 'distance', 'target', 'light'"

if reward_type == "distance":
    environment_configuration = BrittleStarUndirectedLocomotionEnvironmentConfiguration(
        # If this value is > 0 then we will add randomly sampled noise to the initial joint positions and velocities 
        joint_randomization_noise_scale=joint_randomization_noise_scale,
        render_mode="rgb_array",  # Visualization mode 
        simulation_time=simulation_time,  # Number of seconds per episode 
        num_physics_steps_per_control_step=num_physics_steps_per_control_step,  # Number of physics substeps to do per control step 
        time_scale=2,    # Integer factor by which to multiply the original physics timestep of 0.002,
        camera_ids=[0, 1],   # Which camera's to render (all the brittle star environments contain 2 cameras: 1 top-down camera and one close-up camera that follows the brittle star),
        render_size=(480, 640)  # Resolution to render with ((height, width) in pixels)
        )

elif reward_type == "target":
    environment_configuration = BrittleStarDirectedLocomotionEnvironmentConfiguration(
        # Distance to put our target at (targets are spawned on a circle around the starting location with this given radius).
        target_distance=3.0,
        joint_randomization_noise_scale=joint_randomization_noise_scale,
        render_mode="rgb_array",
        simulation_time=simulation_time,
        num_physics_steps_per_control_step=num_physics_steps_per_control_step,
        time_scale=2,
        camera_ids=[0, 1],
        render_size=(480, 640)
        )

elif reward_type == "light":
    environment_configuration = BrittleStarLightEscapeEnvironmentConfiguration(
            joint_randomization_noise_scale=joint_randomization_noise_scale,
            light_perlin_noise_scale=0,
            # If this value is > 0, we will add perlin noise to the generated light map. Otherwise, the light map is a simple linear gradient.
            # Please only provide integer factors of 200.
            render_mode="rgb_array",
            simulation_time=simulation_time,
            num_physics_steps_per_control_step=num_physics_steps_per_control_step,
            time_scale=2,
            camera_ids=[0, 1],
            render_size=(480, 640)
            )

BACKEND = "MJX"

# useful environment configuration information
print(f"[simulation_time] The total amount of time (in seconds) that one simulation episode takes: {environment_configuration.simulation_time}")
print(f"[physics_timestep] The amount of time (in seconds) that one 'physics step' advances the physics: {environment_configuration.physics_timestep}")
print(f"[control_timestep] The amount of time (in seconds) that one 'control step' advances the physics: {environment_configuration.control_timestep}")
print(f"[total_num_physics_steps] The total amount of physics steps that happen during one simulation episode: {environment_configuration.total_num_physics_steps}")
print(f"[total_num_control_steps] The total amount of control steps that happen during one simulation episode: {environment_configuration.total_num_control_steps}")

20
5


[simulation_time] The total amount of time (in seconds) that one simulation episode takes: 5
[physics_timestep] The amount of time (in seconds) that one 'physics step' advances the physics: 0.004
[control_timestep] The amount of time (in seconds) that one 'control step' advances the physics: 0.04
[total_num_physics_steps] The total amount of physics steps that happen during one simulation episode: 1250
[total_num_control_steps] The total amount of control steps that happen during one simulation episode: 125


In [12]:
env = create_environment(
        morphology_specification=morphology_specification,
        arena_configuration=arena_configuration,
        environment_configuration=environment_configuration,
        backend=BACKEND
        )



if BACKEND == "MJC":
    rng = np.random.RandomState(0)
else:
    rng = jax.random.PRNGKey(seed=0)

state = env.reset(rng=rng)  # Always need to reset the environment before doing anything else with it
frame = env.render(state=state)
media.show_image(post_render(render_output=frame, environment_configuration=environment_configuration))


In [13]:
print("MJX:")
print(f"\t{env.observation_space}")

print("MJX:")
print(f"\t{env.action_space}")

print("First 5 actuators:")
print(f"\tMJX: {env.actuators[:5]}")


MJX:
	Dict('in_plane_joint_position': Box(-0.5235988, 0.5235988, (10,), <class 'jax.numpy.float32'>), 'out_of_plane_joint_position': Box(-0.5235988, 0.5235988, (10,), <class 'jax.numpy.float32'>), 'in_plane_joint_velocity': Box(-inf, inf, (10,), <class 'jax.numpy.float32'>), 'out_of_plane_joint_velocity': Box(-inf, inf, (10,), <class 'jax.numpy.float32'>), 'segment_contact': Box(0.0, 1.0, (10,), <class 'jax.numpy.float32'>), 'disk_position': Box(-inf, inf, (3,), <class 'jax.numpy.float32'>), 'disk_rotation': Box(-3.1415927, 3.1415927, (3,), <class 'jax.numpy.float32'>), 'disk_linear_velocity': Box(-inf, inf, (3,), <class 'jax.numpy.float32'>), 'disk_angular_velocity': Box(-inf, inf, (3,), <class 'jax.numpy.float32'>))
MJX:
	Box(-0.5235988, 0.5235988, (20,), <class 'jax.numpy.float32'>)
First 5 actuators:
	MJX: ['BrittleStarMorphology/arm_0_segment_0_in_plane_joint_p_control', 'BrittleStarMorphology/arm_0_segment_0_out_of_plane_joint_p_control', 'BrittleStarMorphology/arm_0_segment_1_in

In [14]:
print(state.observations)
sensors = [key for key in state.observations.keys()]
sensors_with_dim = {s: len(d) for s, d in state.observations.items()}
print(sensors)
print(sensors_with_dim)
print(state.observations["in_plane_joint_position"])

sensor_selection = ['in_plane_joint_position', 'out_of_plane_joint_position', 'segment_contact']
sensor_selection_dim = sum([len(state.observations[sensor]) for sensor in sensor_selection])
print(sensor_selection_dim)


{'in_plane_joint_position': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32), 'out_of_plane_joint_position': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32), 'in_plane_joint_velocity': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32), 'out_of_plane_joint_velocity': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32), 'segment_contact': Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32), 'disk_position': Array([0.  , 0.  , 0.11], dtype=float32), 'disk_rotation': Array([-3.142,  0.   ,  0.   ], dtype=float32), 'disk_linear_velocity': Array([0., 0., 0.], dtype=float32), 'disk_angular_velocity': Array([0., 0., 0.], dtype=float32)}
['in_plane_joint_position', 'out_of_plane_joint_position', 'in_plane_joint_velocity', 'out_of_plane_joint_velocity', 'segment_contact', 'disk_position', 'disk_rotation', 'disk_linear_velocity', 'disk_angular_velocity']
{'in_plane_joint_position': 10, 'out_of_plane_joint_position': 10, 'in_plane_joint_velocity'

Let us base the development of the first controller on a brittle star with 2 arms and 5 segments, 10 segments in total, 20 degrees of freedom, meaning 20 joint angle positions to be controlled. This means the final layer of our output is 20.
The inputs will be sampled from the Brittle Star Observation space, but let's for now just initialise an input space with 10 random inputs.

### instantiating the neural network

In [15]:
features = [128,128, dofs]
model = ExplicitMLP(features = features)

# initialising the parameters of the model
rng, rng_input, rng_init = jax.random.split(rng, 3)
x = jnp.zeros(sensor_selection_dim) # just required for the initialisation: only length matters, not input values
print(x)
params = model.init(rng_init, x)
# params is a PyTree --> see jax documentation
# print(params)
print(features)
print(params['params'].keys())
print(jax.tree_util.tree_map(lambda x: x.shape, params))
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[128, 128, 20]
dict_keys(['layers_0', 'layers_1', 'layers_2'])
{'params': {'layers_0': {'bias': (128,), 'kernel': (30, 128)}, 'layers_1': {'bias': (128,), 'kernel': (128, 128)}, 'layers_2': {'bias': (20,), 'kernel': (128, 20)}}}
initialized parameter shapes:
 {'params': {'layers_0': {'bias': (128,), 'kernel': (30, 128)}, 'layers_1': {'bias': (128,), 'kernel': (128, 128)}, 'layers_2': {'bias': (20,), 'kernel': (128, 20)}}}


A single forward pass through the model

In [16]:
model.apply(params,x)

Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

Indeed, all outputs are distributed between -1 and +1, as expected from the tanh activation fucntion of the output layer

In [17]:
print(30*128+128*128+128*20+128+128+20)
param_reshaper = ParameterReshaper(params)
num_params = param_reshaper.total_params # get from the weights and biases of the NN

23060
ParameterReshaper: 23060 parameters detected for optimization.


### Instantiating the search strategy

In [56]:
# instantiate the search strategy
es_popsize = 2
rng, rng_ask, rng_init = jax.random.split(rng, 3)
strategy  = OpenES(popsize = es_popsize, num_dims = num_params)
# still parameters that can be finetuned, like optimisation method, lrate, lrate decay, ...
es_params = strategy.default_params
# # replacing certain parameters:
# es_params = es_params.replace(init_min = -3, init_max = 3)
print(es_params, '\n')

es_state = strategy.initialize(rng_init, es_params)

candidate, es_state = strategy.ask(rng_ask, es_state)
print('candidate solution shape: ', candidate.shape, '\n')

# network parameters have an additional dimension of 100, meaning that every bias and every weight appears 100 times
# 100 parallel trees if you want
net_params = param_reshaper.reshape(candidate)
print(jax.tree_util.tree_map(lambda x: x.shape, net_params))

EvoParams(opt_params=OptParams(lrate_init=0.05, lrate_decay=1.0, lrate_limit=0.001, momentum=None, beta_1=0.99, beta_2=0.999, beta_3=None, eps=1e-08, max_speed=None), sigma_init=0.03, sigma_decay=1.0, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38) 

candidate solution shape:  (2, 23060) 

{'params': {'layers_0': {'bias': (2, 128), 'kernel': (2, 30, 128)}, 'layers_1': {'bias': (2, 128), 'kernel': (2, 128, 128)}, 'layers_2': {'bias': (2, 20), 'kernel': (2, 128, 20)}}}


In [57]:
print(x, x.shape)

[[ 0.011  0.032 -0.033 ... -0.026 -0.014 -0.009]
 [ 0.03  -0.004  0.006 ...  0.006 -0.024 -0.029]
 [-0.003  0.017  0.085 ... -0.031 -0.017 -0.017]
 ...
 [ 0.058 -0.036 -0.039 ... -0.059  0.006  0.012]
 [ 0.002 -0.007 -0.032 ...  0.015  0.029  0.016]
 [ 0.042  0.017  0.01  ... -0.008  0.061 -0.044]] (100, 23060)


In [58]:
rng, rng_uniform = jax.random.split(rng, 2)
inputs = jax.random.uniform(rng, (100, sensor_selection_dim))
print(inputs.shape)
print(inputs)

(100, 30)
[[0.629 0.024 0.31  ... 0.269 0.925 0.782]
 [0.7   0.646 0.871 ... 0.861 0.453 0.046]
 [0.6   0.698 0.466 ... 0.885 0.135 0.378]
 ...
 [0.561 0.677 0.468 ... 0.862 0.742 0.074]
 [0.047 0.678 0.602 ... 0.261 0.956 0.605]
 [0.483 0.924 0.576 ... 0.015 0.79  0.226]]


In [59]:
outputs = jax.vmap(model.apply)(net_params, inputs)
print(outputs.shape)

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (6 of them) had size 2, e.g. axis 0 of argument variables['params']['layers_0']['bias'] of type float32[2,128];
  * one axis had size 100: axis 0 of argument args[0] of type float32[100,30]

In [60]:
NUM_MJX_ENVIRONMENTS = es_popsize
num_generations = 50 # 2500
# print_every_k_gens = 100 --> replace by wandb monitoring

# generate 100 parallel environments
# set this to the number of CUDA cores that you have available

mjx_vectorized_env = create_environment(
                morphology_specification=morphology_specification,
                arena_configuration=arena_configuration,
                environment_configuration=environment_configuration,
                backend="MJX"
                )

mjx_action_rng, mjx_vectorized_env_rng = jax.random.split(jax.random.PRNGKey(0), 2)
mjx_vectorized_env_rng = jnp.array(jax.random.split(mjx_vectorized_env_rng, NUM_MJX_ENVIRONMENTS))

mjx_vectorized_step = jax.jit(jax.vmap(mjx_vectorized_env.step))
mjx_vectorized_reset = jax.jit(jax.vmap(mjx_vectorized_env.reset))
mjx_vectorized_action_sample = jax.jit(jax.vmap(mjx_vectorized_env.action_space.sample))

mjx_vectorized_state = mjx_vectorized_reset(rng=mjx_vectorized_env_rng)

In [61]:
vectorized_model_apply = jax.jit(jax.vmap(model.apply))

In [62]:
print(f"\t{mjx_vectorized_env.observation_space}")
print(sensor_selection)

sensory_input_nn = []
for sensor in sensor_selection:
    print(mjx_vectorized_state.observations[sensor].shape)
    sensory_input_nn.append(mjx_vectorized_state.observations[sensor])

print(jnp.array(sensory_input_nn).shape)
sensory_input_nn = jnp.concatenate(sensory_input_nn, axis = 1)
print(jnp.array(sensory_input_nn).shape)

	Dict('in_plane_joint_position': Box(-0.5235988, 0.5235988, (10,), <class 'jax.numpy.float32'>), 'out_of_plane_joint_position': Box(-0.5235988, 0.5235988, (10,), <class 'jax.numpy.float32'>), 'in_plane_joint_velocity': Box(-inf, inf, (10,), <class 'jax.numpy.float32'>), 'out_of_plane_joint_velocity': Box(-inf, inf, (10,), <class 'jax.numpy.float32'>), 'segment_contact': Box(0.0, 1.0, (10,), <class 'jax.numpy.float32'>), 'disk_position': Box(-inf, inf, (3,), <class 'jax.numpy.float32'>), 'disk_rotation': Box(-3.1415927, 3.1415927, (3,), <class 'jax.numpy.float32'>), 'disk_linear_velocity': Box(-inf, inf, (3,), <class 'jax.numpy.float32'>), 'disk_angular_velocity': Box(-inf, inf, (3,), <class 'jax.numpy.float32'>))
['in_plane_joint_position', 'out_of_plane_joint_position', 'segment_contact']
(2, 10)
(2, 10)
(2, 10)
(3, 2, 10)
(2, 30)


In [64]:
print(mjx_vectorized_state.reward)

[0 0]


In [67]:
# Run ask-eval-tell loop - NOTE: By default minimization!
for gen in range(num_generations):
    print('generation: ', gen)
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    
    x, state = strategy.ask(rng_gen, es_state, es_params)
    x_shaped = param_reshaper.reshape(x) # --> stacked array

    sensory_input_nn = []
    for sensor in sensor_selection:
        sensory_input_nn.append(mjx_vectorized_state.observations[sensor])
    sensory_input_nn = jnp.concatenate(sensory_input_nn, axis = 1)
    
    actions = vectorized_model_apply(x_shaped, sensory_input_nn)
   
    i = 0
    while not jnp.any(mjx_vectorized_state.terminated | mjx_vectorized_state.truncated):
        i += 1
        print(i, end = " ")
        mjx_vectorized_state = mjx_vectorized_step(state=mjx_vectorized_state, action=action)

    


#     fitness = ...  # Your population evaluation fct 
#     # fitness should be an array with population size as len (e.g. 100)
#     state = strategy.tell(x, fitness, state, es_params)

# # Get best overall population member & its fitness
# state.best_member, state.best_fitness

generation:  0
candidate params:  (2, 23060) [[ 0.026  0.021 -0.032 ...  0.055  0.002 -0.01 ]
 [-0.026 -0.021  0.032 ... -0.055 -0.002  0.01 ]]


In [23]:
rng, env_rng, action_rng = jax.random.split(rng, 3)

jit_step = jax.jit(env.step)
jit_reset = jax.jit(env.reset)

env_state = jit_reset(rng=env_rng)

env_frames = []
while not (env_state.terminated | env_state.truncated):
    action_rng, sub_rng = jax.random.split(action_rng, 2)
    action = env.action_space.sample(rng=sub_rng)

    env_state = jit_step(state=env_state, action=action)
    env_frames.append(post_render(env.render(state=env_state), env.environment_configuration))
show_video(images=env_frames)



0
This browser does not support the video tag.


In [25]:
env.close()  # always close the environment after using it!

In [26]:
NUM_MJX_ENVIRONMENTS = 100  # set this to the number of CUDA cores that you have available

mjx_vectorized_env = create_environment(
                morphology_specification=morphology_specification,
                arena_configuration=arena_configuration,
                environment_configuration=environment_configuration,
                backend="MJX"
                )

mjx_action_rng, mjx_vectorized_env_rng = jax.random.split(jax.random.PRNGKey(0), 2)
mjx_vectorized_env_rng = jnp.array(jax.random.split(mjx_vectorized_env_rng, NUM_MJX_ENVIRONMENTS))

mjx_vectorized_step = jax.jit(jax.vmap(mjx_vectorized_env.step))
mjx_vectorized_reset = jax.jit(jax.vmap(mjx_vectorized_env.reset))
mjx_vectorized_action_sample = jax.jit(jax.vmap(mjx_vectorized_env.action_space.sample))

In [None]:
mjx_vectorized_state = mjx_vectorized_reset(rng=mjx_vectorized_env_rng)

mjx_frames = []
i = 0
while not jnp.any(mjx_vectorized_state.terminated | mjx_vectorized_state.truncated):
    i += 1
    print(i)
    mjx_action_rng, *sub_rngs = jnp.array(jax.random.split(mjx_action_rng, NUM_MJX_ENVIRONMENTS + 1))
    action = mjx_vectorized_action_sample(rng=jnp.array(sub_rngs))
    
    mjx_vectorized_state = mjx_vectorized_step(state=mjx_vectorized_state, action=action)
    mjx_frames.append(
            post_render(
                mjx_vectorized_env.render(state=mjx_vectorized_state),
                mjx_vectorized_env.environment_configuration
                )
            )
show_video(images=mjx_frames)

In [1]:
print("Observation space:")
print(f"\t{mjx_vectorized_env.observation_space}")
print("Action space:")
print(f"\t{mjx_vectorized_env.action_space}")

NameError: name 'mjx_vectorized_env' is not defined

In [None]:
mjx_vectorized_env.close()