In [46]:
%load_ext autoreload
%autoreload 2

import jax
from brax import envs
from Rodent_Env_Brax import Rodent
import pickle
from preprocessing.mjx_preprocess import process_clip_to_train
import pickle
import os


# Run settings for this notebook, collected in one place.
# Only "env_name" and the keys marked "forwarded" are read by the cells
# visible in this notebook (they are passed to envs.get_environment); the
# remaining keys are not read here and are presumably consumed by the
# training script -- TODO confirm.
config = dict(
    # Name under which the environment is looked up.
    env_name="rodent",
    algo_name="ppo",
    task_name="run",
    # Not read in the visible cells.
    num_envs=16,
    num_timesteps=500_000_000,
    eval_every=5_000_000,
    episode_length=200,
    batch_size=16,
    learning_rate=1e-4,
    # Forwarded to the environment constructor.
    physics_steps_per_control_step=5,
    too_far_dist=0.1,
    ctrl_cost_weight=0.01,
    pos_reward_weight=100.0,
    quat_reward_weight=3.0,
    healthy_reward=0.25,
    healthy_z_range=(0.035, 0.5),
    terminate_when_unhealthy=True,
    # Not read in the visible cells.
    run_platform="Harvard",
    # Forwarded to the environment constructor.
    solver="cg",
    iterations=7,
    ls_iterations=7,
)

# Make the Rodent environment class available to brax's registry under the
# name "rodent" so it can be looked up by name later.
envs.register_environment("rodent", Rodent)

# Path of the reference clip, relative to the notebook's working directory.
# (Plain string literal: there is nothing to interpolate.)
reference_path = "clips/84.p"

# SECURITY: unpickling can execute arbitrary code embedded in the file.
# Only load clip files that come from a trusted source.
with open(reference_path, "rb") as file:
    reference_clip = pickle.load(file)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [47]:
# Instantiate the environment registered under config["env_name"].
# The tracking targets come from the loaded reference clip; every other
# constructor option is copied from `config` under the same key name.
env_name = config["env_name"]

forwarded_options = (
    "terminate_when_unhealthy",
    "solver",
    "iterations",
    "ls_iterations",
    "too_far_dist",
    "ctrl_cost_weight",
    "pos_reward_weight",
    "quat_reward_weight",
    "healthy_reward",
    "healthy_z_range",
    "physics_steps_per_control_step",
)

env_kwargs = dict(
    track_pos=reference_clip.position,
    track_quat=reference_clip.quaternion,
)
env_kwargs.update((option, config[option]) for option in forwarded_options)

env = envs.get_environment(env_name, **env_kwargs)

self._steps_for_cur_frame: 2.0


In [48]:
# Wrap the environment's reset and step methods with jax.jit so that
# subsequent calls run the compiled versions.
jit_reset, jit_step = (jax.jit(method) for method in (env.reset, env.step))

In [49]:
# Create the PRNG key from a fixed seed and reset the environment with it to
# obtain the initial state.
# NOTE(review): `key` is passed to the reset here and split again by a later
# cell; consider splitting off a dedicated reset key first -- confirm whether
# the reset actually consumes randomness.
seed = 0
key = jax.random.PRNGKey(seed)
state = jit_reset(key)

In [50]:
# Advance the environment by one control step using a random action.
#
# Split the key into a carried key and a one-off subkey: the subkey is used
# for sampling and the carried key is never consumed, so re-running this cell
# does not split a key that has already been used to draw random numbers.
# (On the first run `action_key` is the same subkey the sampler received
# before, so the sampled action is unchanged.)
key, action_key = jax.random.split(key)

# One standard-normal value per actuator (env.sys.nu entries).
action = jax.random.normal(action_key, (env.sys.nu,))
state = jit_step(state, action)

In [51]:
# Display the `info` dictionary carried on the state after the step above.
state.info

{'cur_frame': Array(8, dtype=int32),
 'steps_taken_cur_frame': Array(1, dtype=int32, weak_type=True),
 'summed_pos_distance': Array(2.465989e-05, dtype=float32)}

In [52]:
# Display the shape of the observation array held by the current state.
state.obs.shape

(167,)

In [44]:
# Display the `dist` array of the contact data in the physics pipeline state.
# Presumably one signed distance per contact, negative meaning penetration
# -- TODO confirm against the MuJoCo/MJX documentation.
# NOTE(review): this cell's execution count (44) is lower than that of the
# cells above it, so the output shown may come from an earlier `state`;
# re-run the notebook top to bottom before relying on it.
state.pipeline_state.contact.dist

Array([ 0.02668312,  0.00963982,  0.00365678,  0.07354705,  0.04750727,
        0.04865877,  0.03326368,  0.0376247 ,  0.01900179,  0.01720116,
        0.00490876,  0.01786475,  0.00469432,  0.01842165,  0.00612925,
        0.07880929,  0.05621161,  0.05623355,  0.05488762,  0.04300088,
        0.0606983 ,  0.02790296,  0.04077419,  0.02813101,  0.0419216 ,
        0.03007138,  0.0429426 ,  0.07854627,  0.07861307,  0.02748809,
        0.00816264,  0.00640287,  0.01176222,  0.00704252,  0.00988688,
        0.01540634,  0.02137755,  0.01587353,  0.02297394,  0.01571755,
        0.02194083,  0.02498014, -0.00109657,  0.00111527,  0.00263967,
        0.0002849 ,  0.00068986,  0.01348399,  0.00817339,  0.01384809,
        0.00800651,  0.01211086,  0.00725734], dtype=float32)