# Imports and rendering setup

In [None]:
#@title Check if MuJoCo installation was successful

import distutils.util
import os
import subprocess
# Verify that a GPU is reachable before configuring rendering.
# A non-zero exit status from `nvidia-smi` is treated as "no usable GPU". When
# the executable is missing altogether, subprocess.run raises
# FileNotFoundError; that case is mapped to the same actionable error instead
# of surfacing as an unrelated traceback.
try:
  gpu_unreachable = subprocess.run('nvidia-smi').returncode != 0
except FileNotFoundError:
  gpu_unreachable = True
if gpu_unreachable:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# glvnd locates the Nvidia EGL driver through an ICD config file. That file is
# usually installed as part of an Nvidia driver package, but the Colab kernel
# doesn't install its driver via APT, so the ICD is missing and is written
# here when absent.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
_NVIDIA_ICD_CONFIG = """{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
"""
# An existing config is left untouched.
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as icd_file:
    icd_file.write(_NVIDIA_ICD_CONFIG)

# Configure MuJoCo to use the EGL rendering backend (requires GPU).
# Both variables are assigned before `import mujoco` below.
print('Setting environment variable to use GPU rendering:')
# %env MUJOCO_GL=egl
os.environ['MUJOCO_GL'] = 'egl'
os.environ['PYOPENGL_PLATFORM'] = 'egl'
try:
  print('Checking that the installation succeeded:')
  import mujoco
  # Building a model from an empty MJCF document is a minimal smoke test.
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  # Raise the actionable message as the primary error and keep the underlying
  # failure attached as its cause, so the traceback reads in causal order.
  raise RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".') from e

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs.
# Every flag is followed by a single space in the final value.
# Options left disabled:
#   --xla_gpu_enable_async_collectives=true
#   --xla_gpu_enable_latency_hiding_scheduler=true
#   --xla_gpu_enable_highest_priority_async_stream=true
_XLA_GPU_FLAGS = (
    "--xla_gpu_enable_triton_softmax_fusion=true",
    "--xla_gpu_triton_gemm_any=True",
)
# os.environ['CUDA_VISIBLE_DEVICES'] = '0' # Use GPU 1
os.environ["XLA_FLAGS"] = "".join(f"{flag} " for flag in _XLA_GPU_FLAGS)

In [None]:
%load_ext autoreload
%autoreload 2
import os

# Memory-fraction setting for the XLA Python client, assigned before
# `import jax` below.
# NOTE(review): presumably caps the share of GPU memory jax preallocates --
# confirm against the JAX GPU memory allocation docs.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Use GPU 1
import functools
import jax
# jax.config.update("jax_enable_x64", True)

# Count the GPU devices jax reports and print the number.
n_gpus = jax.device_count(backend="gpu")
print(f"Using {n_gpus} GPUs")
from typing import Dict
from brax import envs
import mujoco
import pickle
import warnings
import mediapy as media
import hydra
import jax.numpy as jp

from omegaconf import DictConfig, OmegaConf
from brax.training.agents.ppo import networks as ppo_networks
from custom_brax import custom_ppo as ppo
from custom_brax import custom_wrappers
from custom_brax import custom_ppo_networks
from orbax import checkpoint as ocp
from flax.training import orbax_utils
from preprocessing.mjx_preprocess import process_clip_to_train
from envs.Fly_Env_Brax import FlyTracking, FlyMultiClipTracking
from utils.utils import *
from utils.fly_logging import log_eval_rollout

# Install a filter that ignores DeprecationWarning for the rest of the session.
warnings.filterwarnings("ignore", category=DeprecationWarning)
# jax.config.update("jax_enable_x64", True)

from hydra import initialize, compose
from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra


# Load configs

In [None]:
# Find every saved run configuration below the checkpoint root and print each
# one with its index so a run can be picked through `cfg_num`.
base_dir = '/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt'
run_cfg_list = sorted(Path(base_dir).rglob('run_config.yaml'))
for n, run_cfg in enumerate(run_cfg_list):
    print(n, run_cfg)

# Load the chosen run. Its id is the text after '=' in the name of the
# directory two levels above the config file.
cfg_num = 3
selected_cfg_path = run_cfg_list[cfg_num]
cfg = OmegaConf.load(selected_cfg_path)
run_id = int(selected_cfg_path.parent.parent.stem.split('=')[1])

In [4]:
# Compose a Hydra config for the selected dataset and run id, register it with
# HydraConfig, then copy only its path settings into the loaded run config.
dataset = "multiclip"
hydra_overrides = [
    f"dataset=fly_{dataset}",
    f"train=train_fly_{dataset}",
    "paths=walle",
    "version=ckpt",
    f'run_id={run_id}',
]
with initialize(version_base=None, config_path="configs"):
    cfg_temp = compose(
        config_name='config.yaml',
        overrides=hydra_overrides,
        return_hydra_config=True,
    )
    HydraConfig.instance().set_config(cfg_temp)

cfg.paths = cfg_temp.paths

In [5]:
# Turn every configured path except the 'user' entry into a Path object and
# make sure the directory exists.
for k in cfg.paths.keys():
    if k == 'user':
        continue
    cfg.paths[k] = Path(cfg.paths[k])
    cfg.paths[k].mkdir(parents=True, exist_ok=True)
env_cfg = cfg.dataset
env_args = cfg.dataset.env_args

# Checkpoints live in the 'ckpt' sibling of the configured base directory.
cfg.paths.base_dir = cfg.paths.base_dir.parent / 'ckpt'
reference_path = cfg.paths.data_dir / "clips" / "all_clips_batch_interp.p"
# reference_path = cfg.paths.data_dir/ f"clips/{cfg.dataset['clip_idx']}"
reference_path.parent.mkdir(parents=True, exist_ok=True)

# Unpickle the reference clips (pickle must only be used on trusted files).
with open(reference_path, "rb") as clip_file:
    reference_clip = pickle.load(clip_file)


# Load env

In [None]:
from envs.Fly_Env_Brax import FlyTracking, FlyMultiClipTracking, FlyRunSim, _bounded_quat_dist
# dataset = 'multiclip'

# with initialize(version_base=None, config_path="configs"):
#     cfg=compose(config_name='config.yaml',overrides= [f"dataset=fly_{dataset}", f"train=train_fly_{dataset}", "paths=walle"],return_hydra_config=True,)
#     HydraConfig.instance().set_config(cfg)


# env_args = cfg.dataset.env_args
# Register the three fly environments under their Brax names, then build the
# one selected by the training config.
for env_name, env_cls in (
    ("fly_freejnt_clip", FlyTracking),
    ("fly_freejnt_multiclip", FlyMultiClipTracking),
    ("fly_run_policy", FlyRunSim),
):
    envs.register_environment(env_name, env_cls)
# cfg.dataset.env_args.mjcf_path = '/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_viz_only.xml'
print(cfg.train.env_name)
env = envs.get_environment(
    cfg.train.env_name, reference_clip=reference_clip, **cfg.dataset.env_args
)

In [None]:
# Compose the default multiclip config (no checkpoint overrides) and register
# it with HydraConfig.
dataset = "multiclip"
hydra_overrides = [
    f"dataset=fly_{dataset}",
    f"train=train_fly_{dataset}",
    "paths=walle",
]
with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name='config.yaml',
        overrides=hydra_overrides,
        return_hydra_config=True,
    )
    HydraConfig.instance().set_config(cfg)

# Turn every configured path except the 'user' entry into a Path object and
# make sure the directory exists.
for k in cfg.paths.keys():
    if k == 'user':
        continue
    cfg.paths[k] = Path(cfg.paths[k])
    cfg.paths[k].mkdir(parents=True, exist_ok=True)
env_cfg = cfg.dataset
env_args = cfg.dataset.env_args
reference_path = cfg.paths.data_dir / "clips" / "all_clips_batch_interp.p"
# reference_path = cfg.paths.data_dir/ f"clips/0.p"
reference_path.parent.mkdir(parents=True, exist_ok=True)

# Unpickle the reference clips (pickle must only be used on trusted files).
with open(reference_path, "rb") as clip_file:
    reference_clip = pickle.load(clip_file)


In [None]:
# Compose the default multiclip config, register the fly environments and
# build the environment named by the training config.
dataset = 'multiclip'

from envs.Fly_Env_Brax import FlyTracking, FlyMultiClipTracking, FlyRunSim, _bounded_quat_dist
hydra_overrides = [
    f"dataset=fly_{dataset}",
    f"train=train_fly_{dataset}",
    "paths=walle",
]
with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name='config.yaml',
        overrides=hydra_overrides,
        return_hydra_config=True,
    )
    HydraConfig.instance().set_config(cfg)

env_args = cfg.dataset.env_args
print(cfg.train.env_name)
for env_name, env_cls in (
    ("fly_freejnt_clip", FlyTracking),
    ("fly_freejnt_multiclip", FlyMultiClipTracking),
    ("fly_run_policy", FlyRunSim),
):
    envs.register_environment(env_name, env_cls)
# Saved run config of the checkpointed run; its paths are replaced by the
# freshly composed ones. Alternative runs:
# cfg_load = OmegaConf.load('/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt/run_id=21377073//logs/run_config.yaml')
# cfg_load = OmegaConf.load('/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt/run_id=21356039/logs/run_config.yaml')
cfg_load = OmegaConf.load('/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt/run_id=21402781/logs/run_config.yaml')
cfg_load.paths = cfg.paths
env = envs.get_environment(
    cfg.train.env_name, reference_clip=reference_clip, **cfg.dataset.env_args
)

In [7]:
# Wrap the environment for rendered tracking rollouts.
rollout_env = custom_wrappers.RenderRolloutWrapperTracking(env)

# JIT-compiled step and reset functions of the wrapped environment.
jit_step = jax.jit(rollout_env.step)
jit_reset = jax.jit(rollout_env.reset)

# Initial state from a fixed seed.
state = jit_reset(jax.random.PRNGKey(0))


# Test inference

In [8]:
from orbax import checkpoint as ocp
from flax.training import orbax_utils
import optax

In [None]:
# Checkpoint directory of the run to evaluate (alternative runs kept below).
# model_path = Path('/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt/run_id=21390629/ckpt/21390629')
# model_path = Path('/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt/run_id=21377073/ckpt/21377073')
# model_path = Path('/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt/run_id=21356039/ckpt/21356039')
model_path = Path('/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt/run_id=21402781/ckpt/21402781')
# model_path = Path(cfg.paths.ckpt_dir/f"{run_id}")


def _checkpoint_order(ckpt_dir):
    """Sort key placing all-digit directory names last, in numeric order.

    A plain path sort is lexicographic, so '1000' would sort before '200' and
    the "latest" checkpoint would be wrong whenever step counts differ in
    their number of digits. Non-numeric names sort first, by name.
    """
    name = ckpt_dir.name
    if name.isdigit():
        return (1, int(name), name)
    return (0, 0, name)


##### Get all the checkpoint files #####
# NOTE(review): assumes checkpoint sub-directories are named by step count,
# as policy_params_fn in the next cell names them -- confirm for this run.
with os.scandir(model_path) as dir_entries:
    ckpt_files = sorted(
        (Path(entry.path) for entry in dir_entries if entry.is_dir()),
        key=_checkpoint_order,
    )
if not ckpt_files:
    raise FileNotFoundError(f"No checkpoint directories found in {model_path}")
max_ckpt = ckpt_files[-1].as_posix()
env_args = cfg.dataset.env_args
print(max_ckpt)

In [None]:
def policy_params_fn(num_steps, make_policy, params, policy_params_fn_key, model_path=model_path):
  """Write ``params`` as an Orbax checkpoint named after ``num_steps``.

  ``make_policy``, ``policy_params_fn_key`` and ``model_path`` are accepted
  but not used; the checkpoint always goes below the fixed debug directory,
  overwriting an existing one of the same name (``force=True``).
  """
  ckpt_dir = Path('/data/users/eabe/biomech_model/Flybody/RL_Flybody/debug/Test_paths') / f'{num_steps}'
  ocp.PyTreeCheckpointer().save(
      ckpt_dir,
      params,
      force=True,
      save_args=orbax_utils.save_args_from_target(params),
  )
  
  
# Episode length: clip frames minus a fixed margin of 50 and the reference
# trajectory length, times env._steps_for_cur_frame.
usable_frames = env_args.clip_length - 50 - env_cfg.ref_traj_length
episode_length = usable_frames * env._steps_for_cur_frame
print(f"episode_length {episode_length}")

# Settings forwarded unchanged from cfg.train under the same keyword name.
_PPO_CFG_KEYS = (
    "num_envs",
    "num_resets_per_eval",
    "reward_scaling",
    "action_repeat",
    "clipping_epsilon",
    "unroll_length",
    "num_minibatches",
    "num_updates_per_batch",
    "discounting",
    "learning_rate",
    "kl_weight",
    "entropy_cost",
    "batch_size",
    "seed",
)
_NETWORK_CFG_KEYS = (
    "encoder_hidden_layer_sizes",
    "decoder_hidden_layer_sizes",
    "value_hidden_layer_sizes",
)

network_factory = functools.partial(
    custom_ppo_networks.make_intention_ppo_networks,
    **{key: cfg.train[key] for key in _NETWORK_CFG_KEYS},
)

# num_timesteps=0 together with restore_checkpoint_path: the call below is
# used to obtain the restored parameters and the inference-function factory.
train_fn = functools.partial(
    ppo.train,
    num_timesteps=0,
    num_evals=1,
    episode_length=episode_length,
    normalize_observations=True,
    network_factory=network_factory,
    restore_checkpoint_path=max_ckpt,
    **{key: cfg.train[key] for key in _PPO_CFG_KEYS},
)

make_inference_fn, params, _ = train_fn(environment=env)
# Inference needs the first params entry plus the policy part of the second.
policy_params = (params[0], params[1].policy)
params2 = policy_params
# Env_steps = params[2]
jit_inference_fn = jax.jit(make_inference_fn(policy_params, deterministic=True))

In [None]:
# NOTE(review): this writes clip_idx on the state left over from the previous
# reset; `state` is replaced by jit_reset a few lines below, so the value set
# here does not reach the rollout -- confirm the intended clip selection.
state.info['clip_idx'] = 1
print(state.info['clip_idx'])
# reset_rng, act_rng = jax.random.split(policy_params_fn_key)
rng = jax.random.PRNGKey(0)
reset_rng, act_rng = jax.random.split(rng)
state = jit_reset(reset_rng)

# Roll the deterministic policy forward, keeping every state and control.
rollout = [state]
# rollout_len = env_args["clip_length"]*int(rollout_env._steps_for_cur_frame)
rollout_len = 1000
ctrl_all = []
for i in range(rollout_len):
    # Advance the action key: keep the second half of the split.
    act_rng = jax.random.split(act_rng)[1]
    obs = state.obs
    ctrl, extras = jit_inference_fn(obs, act_rng)
    state = jit_step(state, ctrl)
    ctrl_all.append(ctrl.copy())
    rollout.append(state)
# Physics states only, as passed to the renderer in the next cell.
rollout2 = [step_state.pipeline_state for step_state in rollout]


In [None]:
# Visualisation options: transparent bodies plus contact points and forces.
scene_option = mujoco.MjvOption()
for vis_flag in (
    mujoco.mjtVisFlag.mjVIS_TRANSPARENT,
    mujoco.mjtVisFlag.mjVIS_CONTACTPOINT,
    mujoco.mjtVisFlag.mjVIS_CONTACTFORCE,
):
    scene_option.flags[vis_flag] = True

# Render the policy rollout from the 'track1' camera and show it at 50 fps.
pixels = rollout_env.render(
    rollout2,
    camera='track1',
    width=480,
    height=480,
    scene_option=scene_option,
)
media.show_video(pixels, fps=50)

In [None]:
# Generalized positions of the policy rollout, one row per state.
qposes_rollout = np.array([state.pipeline_state.qpos for state in rollout])

# Reference trajectory for the clip of the first rollout state; each reference
# frame is repeated env._steps_for_cur_frame times along the time axis.
ref_traj = env._get_reference_clip(rollout[0].info)
print(f"clip_id:{rollout[0].info['clip_idx']}")
qposes_ref = np.repeat(
    np.hstack([ref_traj.position, ref_traj.quaternion, ref_traj.joints]),
    env._steps_for_cur_frame,
    axis=0,
)

# Compile the rendering model and align its solver settings with the env.
spec = mujoco.MjSpec.from_file(cfg.dataset.rendering_mjcf)
mj_model = spec.compile()
print(cfg.dataset.rendering_mjcf)
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = cfg.dataset.env_args.iterations
mj_model.opt.ls_iterations = cfg.dataset.env_args.ls_iterations
mj_model.opt.timestep = env.sys.mj_model.opt.timestep

mj_data = mujoco.MjData(mj_model)
# Colour every site whose name contains "-0" red.
site_id = [
    mj_model.site(i).id
    for i in range(mj_model.nsite)
    if "-0" in mj_model.site(i).name
]
# The loop variable is deliberately not called `id`: at notebook top level
# that would shadow the builtin for every later cell.
for highlighted_site in site_id:
    mj_model.site(highlighted_site).rgba = [1, 0, 0, 1]

scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True

# NOTE(review): mujoco is already imported when this variable is changed;
# confirm whether the backend switch has any effect at this point.
os.environ["MUJOCO_GL"] = "osmesa"
mujoco.mj_kinematics(mj_model, mj_data)
# renderer = mujoco.Renderer(mj_model, height=512, width=512)

frames = []
# Render one image per time step. qpos is the policy pose followed by the
# reference pose (presumably the rendering model contains both bodies).
with mujoco.Renderer(mj_model, height=480, width=480) as renderer:
    for qpos1, qpos2 in zip(qposes_rollout, qposes_ref):
        mj_data.qpos = np.append(qpos1, qpos2)
        mujoco.mj_forward(mj_model, mj_data)
        renderer.update_scene(mj_data, camera=1, scene_option=scene_option)
        # Append directly instead of binding the image to `pixels`, so the
        # `pixels` video from the earlier render cell stays intact for the
        # plotting cell that indexes it by frame.
        frames.append(renderer.render())

In [None]:
# Play back the images collected in `frames` at 50 fps.
media.show_video(frames,fps=50)

In [None]:
# Print every configured joint name with its index. The names are read
# straight from the config so this cell also works before `joint_names` has
# been assigned by a later cell.
for n, joint in enumerate(cfg.dataset.env_args.joint_names):
    print(f'{n}: {joint}')

In [None]:
# Compare policy and reference joint angles for two selected joints.
joint_idxs = env._joint_idxs
joint_names = cfg.dataset.env_args.joint_names

# Keep only the qpos columns listed in env._joint_idxs.
joints_policy = qposes_rollout[:, joint_idxs]
joints_ref = qposes_ref[:, joint_idxs]

fig, axs = plt.subplots(2, 1, figsize=(10, 10))
# One panel per joint index. The two curves get distinct labels and the
# legend is drawn, so the labels are actually visible and distinguishable.
for ax, i in zip(axs, (10, 12)):
    ax.plot(joints_ref[:, i], label=f'{joint_names[i]} (reference)')
    ax.plot(joints_policy[:, i], label=f'{joint_names[i]} (policy)')
    ax.legend()


In [None]:
# Replay the recorded policy controls in plain MuJoCo on the model at
# `model_path`, recording frames, sensor data, qpos and contact counts.
ctrl_all = jp.array(ctrl_all)
model_path = "/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_force_fast_amp.xml"
spec = mujoco.MjSpec()
spec = spec.from_file(model_path)
thorax = spec.find_body("thorax")
first_joint = thorax.first_joint()
# first_joint.delete()
root = spec.compile()
# Solver options come from the dataset config.
root.opt.solver = {
    "cg": mujoco.mjtSolver.mjSOL_CG,
    "newton": mujoco.mjtSolver.mjSOL_NEWTON,
}[cfg.dataset.env_args.solver.lower()]
root.opt.iterations = env_args.iterations
root.opt.ls_iterations = env_args.ls_iterations
root.opt.timestep = env_args.physics_timestep
root.opt.jacobian = 0
data = mujoco.MjData(root)
# Start from the first pose of the policy rollout.
data.qpos = qposes_rollout[0]
mujoco.mj_forward(root, data)
n_frames = ctrl_all.shape[0]
height = 512
width = 512
frames = []
# One recorded control per environment step of length env.dt.
fps = 1/env.dt
times = []
sensordata = []
scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]

scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True

# Per-step values are stored as copies. `data` itself is not collected: it is
# one mutable object, so a list of it would hold N references to the final
# state, and binding that list to `rollout` would overwrite the policy
# rollout that the sensor-data cell below reads through
# `state.pipeline_state`.
qpos_all, ncon_all = [], []
with mujoco.Renderer(root, height, width) as renderer:
    for i in range(n_frames):
        data.ctrl = ctrl_all[i].copy()
        # data.qpos = qposes_rollout[i].copy()
        # Step the physics until simulated time reaches this control's slot.
        while data.time < i/fps:
            mujoco.mj_step(root, data)
        sensordata.append(data.sensordata.copy())
        times.append(data.time)
        renderer.update_scene(data, camera='track1', scene_option=scene_option)
        frames.append(renderer.render())
        qpos_all.append(data.qpos.copy())
        ncon_all.append(data.ncon)

media.show_video(frames, fps=50)


In [17]:
# 1 (mg cm) / (s^2) = 1.0 × 10^-8 newtons

In [None]:
# Show the keys stored for clip `n`.
# NOTE(review): depends on `torque_data`, which is assigned in the following
# cell, and on `n` left over from an earlier loop -- confirm the intended
# execution order.
torque_data[f'clip{n:02d}'].keys()

In [None]:
# Load the multiclip torque-control file and collect the 'sensordata' entry
# of clips 00 through 16.
torque_data = ioh5.load(
    '/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/Torque_control_multiclip.h5'
)
mclip_sensor = [
    torque_data[f'clip{clip_num:02d}']['sensordata'] for clip_num in range(17)
]

In [226]:
# Stack the sensor readings of every rollout state into one array
# (first axis: rollout step).
sensordata = np.stack(
    [step_state.pipeline_state.sensordata for step_state in rollout]
)#[:,:-5]
# Same array under a second name; the column split is currently disabled.
touch_sensors = sensordata#[:, -5:]

In [None]:
# Plot, for each sensor column, the largest absolute value over the first axis.
plt.plot(np.max(np.abs(sensordata), axis=0))

In [None]:
# Same plot as the previous cell: per-column maximum of |sensordata| over the
# first axis.
plt.plot(np.max(np.abs(sensordata), axis=0))

In [None]:
# Smooth the sensor traces with an N-sample moving average and plot a window
# of them per end effector next to one rendered frame.
N = 5
# Time x end_eff x xyz, x=forward 1e-8*
sdata = sensordata.reshape(sensordata.shape[0], -1, 3)
smoothing_kernel = np.ones(N) / N
# Extra arguments are forwarded to np.convolve for every 1-D slice along axis 0.
sdata = np.apply_along_axis(np.convolve, 0, sdata, smoothing_kernel, mode='full')
end_eff = [
    # 'claw_T1_left',
    'claw_T1_right',
    'claw_T2_left',
    'claw_T2_right',
    'claw_T3_left',
    'claw_T3_right',
]
# Window start, window length and the frame marked inside the window.
t = 100
dt = 100
t_frame = 45
fig, axs = plt.subplots(3, 2, figsize=(10, 10))
axs = axs.flatten()
# First panel: the rendered frame at the marked time.
ax = axs[0]
ax.imshow(pixels[t+t_frame])
ax.axis('off')
# Remaining panels: the three components of each end effector over the window.
for n, ax in enumerate(axs[1:len(end_eff) + 1]):
    for component in range(3):
        ax.plot(sdata[t:t+dt, n, component])
    ax.axvline(t_frame, color='r')
# plt.plot(sdata[:,:,2])

In [None]:
# Plot sensor column 1 over the window [t, t + dt). `dt` is a window length,
# not an end index: with t == dt the slice `t:dt` would be empty.
plt.plot(touch_sensors[t:t+dt, 1])

In [None]:
# Plot the full, heavily smoothed (N-sample moving average) sensor traces,
# scaled by 1e-8, on a shared y axis.
N = 50
# Time x end_eff x xyz, x=forward
# NOTE(review): the 1e-8 factor presumably converts mg*cm/s^2 to newtons --
# confirm the sensor units.
sdata = 1e-8*(np.stack(sensordata).reshape(-1, 6, 3))
smoothing_kernel = np.ones(N) / N
# Extra arguments are forwarded to np.convolve for every 1-D slice along axis 0.
sdata = np.apply_along_axis(np.convolve, 0, sdata, smoothing_kernel, mode='full')

fig, axs = plt.subplots(3, 2, figsize=(10, 10), sharey=True)
axs = axs.flatten()
# One panel per entry of end_eff, three components each.
for n, ax in zip(range(len(end_eff)), axs):
    for component in range(3):
        ax.plot(sdata[:, n, component])
# plt.plot(sdata[:,:,2])