# Imports and rendering setup

In [None]:
%load_ext autoreload
%autoreload 2
import distutils.util
import os
import subprocess
# Headless rendering backend for MuJoCo / PyOpenGL.
os.environ['MUJOCO_GL'] = 'egl'
os.environ['PYOPENGL_PLATFORM'] = 'egl'
# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs.
# XLA_FLAGS is assigned exactly once: the previous version assigned it twice
# with the same active flags, which made the first assignment dead code.
# Optional flags are kept below, commented out, for quick experiments.
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_triton_softmax_fusion=true "
    "--xla_gpu_triton_gemm_any=True "
    # "--xla_gpu_enable_async_collectives=true "
    # "--xla_gpu_enable_latency_hiding_scheduler=true "
    # "--xla_gpu_enable_highest_priority_async_stream=true "
)
# Use GPU 1 only; set here, before `import jax` below, so JAX picks it up.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Fraction of the visible GPU's memory JAX preallocates.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"
import functools
import jax
# jax.config.update("jax_enable_x64", True)

# Number of GPU devices JAX can see (after the CUDA_VISIBLE_DEVICES setting above).
n_gpus = jax.device_count(backend="gpu")
print(f"Using {n_gpus} GPUs")
from typing import Dict
from brax import envs
import mujoco
import pickle
import warnings
import mediapy as media
import hydra
import jax.numpy as jp

from tqdm.auto import tqdm
from omegaconf import DictConfig, OmegaConf
from brax.training.agents.ppo import networks as ppo_networks
from brax import math as brax_math
from custom_brax import custom_ppo as ppo
from custom_brax import custom_wrappers
from custom_brax import network_masks as masks
from custom_brax import custom_ppo_networks
from orbax import checkpoint as ocp
from flax.training import orbax_utils
from preprocessing.mjx_preprocess import process_clip_to_train
from envs.Fly_Env_Brax import FlyTracking, FlyMultiClipTracking
from utils.utils import *
from utils.fly_logging import log_eval_rollout

# Globally ignore DeprecationWarning from this point on.
warnings.filterwarnings("ignore", category=DeprecationWarning)
# jax.config.update("jax_enable_x64", True)

from hydra import initialize, compose
from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra


##### Plotting settings ######
import matplotlib as mpl
# Figure-style defaults: 10 pt text, thick axes and ticks, no top/right spines,
# white background, and text kept editable in exported files
# (pdf.fonttype 42 = TrueType in PDF, svg.fonttype 'none' = text stays text).
mpl.rcParams.update({'font.size':          10,
                     'axes.linewidth':     2,
                     'xtick.major.size':   5,
                     'ytick.major.size':   5,
                     'xtick.major.width':  2,
                     'ytick.major.width':  2,
                     'axes.spines.right':  False,
                     'axes.spines.top':    False,
                     'pdf.fonttype':       42,
                     'xtick.labelsize':    10,
                     'ytick.labelsize':    10,
                     'figure.facecolor':   'white',
                     'pdf.use14corefonts': True,
                     'svg.fonttype':       'none',
                     'font.family':        'sans-serif',
                    #  'font.family':        'Arial',
                    #  'font.sans-serif':    'Arial',
                     # NOTE(review): font.family is 'sans-serif', so this serif
                     # setting likely has no visible effect — confirm intent.
                     'font.serif':         'Arial',
                    })

In [None]:
import subprocess as sp

def get_gpu_memory():
    """Get the total memory of every GPU listed by ``nvidia-smi``.

    Returns:
        list[int]: total memory in MiB for each GPU, in nvidia-smi's index
        order. NOTE: nvidia-smi enumerates all GPUs and is generally not
        filtered by ``CUDA_VISIBLE_DEVICES``, so index 0 here may not be the
        device JAX uses.

    Raises:
        subprocess.CalledProcessError / FileNotFoundError: if nvidia-smi fails
        or is not installed.
    """
    # 'noheader,nounits' makes each line a bare integer. We therefore no longer
    # rely on dropping a header line and a trailing empty line; the old
    # split('\n')[:-1][1:] lost the last GPU when there was no trailing newline.
    command = [
        "nvidia-smi",
        "--query-gpu=memory.total",
        "--format=csv,noheader,nounits",
    ]
    output = sp.check_output(command).decode("ascii")
    return [int(line.strip()) for line in output.splitlines() if line.strip()]

def closest_power_of_two(x):
    """Return the largest power of two that is <= ``x``.

    ``x`` may be an int or a float. Any ``x`` below 2 (including zero,
    negatives and fractions) yields 1. The result is always an ``int``.
    """
    result = 1
    # Double while the next power of two still fits under x.
    while (result << 1) <= x:
        result <<= 1
    return result

# Size the number of parallel envs from GPU memory: one env per ~100 MiB,
# rounded down to a power of two.
# NOTE(review): CUDA_VISIBLE_DEVICES is '1' above, but nvidia-smi indexing is
# independent of it, so [0] may describe a different card — confirm.
# num_envs is only displayed here; training below reads cfg.train["num_envs"].
tot_mem = get_gpu_memory()[0]
num_envs = int(closest_power_of_two(tot_mem/100.4)) #21.4
num_envs

# Load configs

In [None]:
# List every saved run config under the checkpoint root.
base_dir ='/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt'
# base_dir ='/gscratch/portia/eabe/biomech_model/Flybody/RL_Flybody/ckpt'
run_cfg_list = sorted(list(Path(base_dir).rglob('run_config.yaml')))
for n, run_cfg in enumerate(run_cfg_list):
    print(n, run_cfg)


# Pick a run by its index in the (lexicographically) sorted list; -1 = last.
cfg_num = -1
cfg = OmegaConf.load(run_cfg_list[cfg_num])
# run_id is parsed from the grandparent directory name, which is assumed to
# look like "<key>=<int>" — TODO confirm the directory layout.
run_id = int(run_cfg_list[cfg_num].parent.parent.stem.split('=')[1])
print(cfg.dataset.dname)
fig_dir = Path('/data/users/eabe/biomech_model/Flybody/RL_Flybody/debug/figures')

In [None]:
# Re-compose the Hydra config for this dataset/run with machine-specific paths
# ("paths=walle"). Only its `paths` node is copied into the loaded run config.
dataset = cfg.dataset.dname
with initialize(version_base=None, config_path="configs"):
    cfg_temp=compose(config_name='config.yaml',overrides= [f"dataset={dataset}", f"train=train_{dataset}", "paths=walle", "version=ckpt", f'run_id={run_id}'],return_hydra_config=True,)
    # Presumably registered so ${hydra:...} interpolations in cfg_temp can
    # resolve outside a @hydra.main app — confirm.
    HydraConfig.instance().set_config(cfg_temp)
    
cfg.paths = cfg_temp.paths

In [None]:
# Turn every path entry except 'user' into a Path and make sure it exists.
for k in cfg.paths.keys():
    if (k != 'user'):
        cfg.paths[k] = Path(cfg.paths[k])
        cfg.paths[k].mkdir(parents=True, exist_ok=True)
env_cfg = cfg.dataset
env_args = cfg.dataset.env_args

# Point base_dir at the sibling 'ckpt' directory.
cfg.paths.base_dir = cfg.paths.base_dir.parent / 'ckpt'
reference_path = cfg.paths.data_dir/ f"clips/all_clips_turn_interp_small.p"
# reference_path = cfg.paths.data_dir/ f"clips/all_clips_batch_interp.p"
# reference_path = cfg.paths.data_dir/ f"clips/{cfg.dataset['clip_idx']}"
# NOTE(review): creating the parent dir has no effect on reading; the file must
# already exist for open() below.
reference_path.parent.mkdir(parents=True, exist_ok=True)

# pickle can execute arbitrary code on load — only use trusted files.
with open(reference_path, "rb") as file:
    # Use pickle.load() to load the data from the file
    reference_clip = pickle.load(file)
# Reference pose per clip: root position, root quaternion and joint angles
# concatenated on the last axis (presumably the qpos layout — confirm).
ref_data = np.concatenate([reference_clip.position,reference_clip.quaternion,reference_clip.joints], axis=-1)

# Load env

In [None]:
from envs.Fly_Env_Brax import FlyTracking, FlyMultiClipTracking, FlyRunSim, _bounded_quat_dist
# dataset = 'multiclip'

# with initialize(version_base=None, config_path="configs"):
#     cfg=compose(config_name='config.yaml',overrides= [f"dataset=fly_{dataset}", f"train=train_fly_{dataset}", "paths=walle"],return_hydra_config=True,)
#     HydraConfig.instance().set_config(cfg)


# env_args = cfg.dataset.env_args
# Register the custom fly environments with brax so they can be built by name.
envs.register_environment("fly_freejnt_clip", FlyTracking)
envs.register_environment("fly_freejnt_multiclip", FlyMultiClipTracking)
envs.register_environment("fly_run_policy", FlyRunSim)
# cfg.dataset.env_args.mjcf_path = '/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_viz_only.xml'
# cfg.dataset.env_args.mjcf_path = '/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_force_fast.xml'
print(cfg.train.env_name)
# cfg.dataset.env_args.iterations = 12
# cfg.dataset.env_args.ls_iterations = 12
# Build the env named in the training config, passing the loaded reference
# clips plus the dataset's env_args as keyword arguments.
env = envs.get_environment(
    cfg.train.env_name,
    reference_clip=reference_clip,
    **cfg.dataset.env_args,
)

In [None]:
# Wrap the env for rendering rollouts, then jit its reset/step functions.
rollout_env = custom_wrappers.RenderRolloutWrapperTracking(env)

# rollout_env = custom_wrappers.RenderRolloutWrapperTracking(env)
# define the jit reset/step functions
jit_reset = jax.jit(rollout_env.reset)
jit_step = jax.jit(rollout_env.step)
# state = jit_reset(jax.random.PRNGKey(0))


# Test inference

In [None]:
from orbax import checkpoint as ocp
from flax.training import orbax_utils
import optax
from natsort import natsorted
from custom_brax import custom_networks
import custom_brax.masked_running_statistics as running_statistics


In [None]:
# Locate the checkpoint directory and pick its last subdirectory in natural
# sort order (numeric-aware, so "10" sorts after "9"). Presumably the subdirs
# are named by training step, making this the latest checkpoint — confirm.
# Raises IndexError if there are no checkpoint subdirectories.
model_path = Path(cfg.paths.ckpt_dir)
# model_path = '/gscratch/portia/eabe/biomech_model/Flybody/RL_Flybody/debug/ckpt'
# max_ckpt = cfg.train.restore_checkpoint
##### Get all the checkpoint files #####
ckpt_files = natsorted([Path(f.path) for f in os.scandir(model_path) if f.is_dir()])
max_ckpt = ckpt_files[-1]
env_args = cfg.dataset.env_args
print(max_ckpt)

In [None]:
def policy_params_fn(num_steps, make_policy, params, policy_params_fn_key, model_path=model_path):
    """Callback for ``ppo.train``: print the policy parameters and nothing else.

    No checkpoint is written here. ``num_steps``, ``make_policy``,
    ``policy_params_fn_key`` and ``model_path`` are accepted only to match the
    expected callback signature. Note that the ``model_path`` default is bound
    to the global ``model_path`` when this function is defined.
    """
    policy = params[1].policy
    print(policy)

# Choose the network factory: encoder-decoder if the config says so,
# otherwise the intention network.
if  ('network_type' in cfg.train) and (cfg.train['network_type'] is not None) and ('encoderdecoder' in cfg.train['network_type']):
    network_type = custom_ppo_networks.make_encoderdecoder_ppo_networks
else: 
    network_type = custom_ppo_networks.make_intention_ppo_networks


# Orbax manager over the run's checkpoint dir with three named items.
options = ocp.CheckpointManagerOptions(save_interval_steps=1)
ckpt_mgr = ocp.CheckpointManager(
    cfg.paths.ckpt_dir,
    item_names=("normalizer_params", "params", "env_steps"),
    options=options,
)
# Episode length in env steps: clip frames minus 50 and the reference window,
# times the number of env steps per clip frame.
episode_length = (env_args.clip_length - 50 - env_args.ref_len) * env._steps_for_cur_frame
print(f"episode_length {episode_length}")
# num_timesteps=0 with checkpoint_path/continue_training set: presumably this
# restores the checkpoint and returns its params without further training —
# confirm in custom_ppo.train.
train_fn = functools.partial(
            ppo.train,
            num_envs=cfg.train["num_envs"],
            num_timesteps=0,
            num_evals=int(cfg.train["num_timesteps"] / cfg.train["eval_every"]),
            num_resets_per_eval=cfg.train['num_resets_per_eval'],
            reward_scaling=cfg.train['reward_scaling'],
            episode_length=episode_length,
            normalize_observations=True,
            action_repeat=cfg.train['action_repeat'],
            clipping_epsilon=cfg.train["clipping_epsilon"],
            unroll_length=cfg.train['unroll_length'],
            num_minibatches=cfg.train["num_minibatches"],
            num_updates_per_batch=cfg.train["num_updates_per_batch"],
            discounting=cfg.train['discounting'],
            learning_rate=cfg.train["learning_rate"],
            kl_weight=cfg.train["kl_weight"],
            entropy_cost=cfg.train['entropy_cost'],
            batch_size=cfg.train["batch_size"],
            seed=cfg.train['seed'],
            network_factory=functools.partial(
                network_type,
                encoder_hidden_layer_sizes=cfg.train['encoder_hidden_layer_sizes'],
                decoder_hidden_layer_sizes=cfg.train['decoder_hidden_layer_sizes'],
                value_hidden_layer_sizes=cfg.train['value_hidden_layer_sizes'],
            ),
            # Architecture of the network stored in the checkpoint
            # (latent size hard-coded to 60).
            checkpoint_network_factory=functools.partial(
                    custom_ppo_networks.make_intention_ppo_networks,
                    intention_latent_size=60,
                    encoder_hidden_layer_sizes=cfg.train.ckpt_net['encoder_hidden_layer_sizes'],
                    decoder_hidden_layer_sizes=cfg.train.ckpt_net['decoder_hidden_layer_sizes'],
                    value_hidden_layer_sizes=cfg.train.ckpt_net['value_hidden_layer_sizes'],
                ),
            checkpoint_path=max_ckpt,
            freeze_mask_fn=None if (cfg.train['freeze_decoder'] == False) else masks.create_decoder_mask,
            continue_training=True,
            custom_wrap=True,  # custom wrappers to handle infos
        )

make_inference_fn, params, _= train_fn(
    environment=env,            
    policy_params_fn=policy_params_fn,
    checkpoint_manager=ckpt_mgr,)
# NOTE(review): params2 duplicates policy_params and is not used in this cell.
params2 = (params[0],params[1].policy)
# (normalizer params, policy params) as expected by make_inference_fn.
policy_params = (params[0],params[1].policy)
# Env_steps = params[2]
jit_inference_fn = jax.jit(make_inference_fn(policy_params, deterministic=True))

In [None]:
# def policy_params_fn(num_steps, make_policy, params, policy_params_fn_key, model_path=model_path):
#     # save checkpoints
#     print(params[1].policy)
# #   orbax_checkpointer = ocp.PyTreeCheckpointer()
# #   save_args = orbax_utils.save_args_from_target(params)
# #   path = Path('/gscratch/portia/eabe/biomech_model/Flybody/RL_Flybody/debug/ckpt') / f'{num_steps}'
# #   orbax_checkpointer.save(path, params, force=True, save_args=save_args)

# # if  ('network_type' in cfg.train) and (cfg.train['network_type'] is not None) and ('encoderdecoder' in cfg.train['network_type']):
# #     network_type = custom_ppo_networks.make_encoderdecoder_ppo_networks
# # else: 
#     # network_type = custom_ppo_networks.make_intention_ppo_networks
# network_type = custom_ppo_networks.make_randomdecoder_ppo_networks


# options = ocp.CheckpointManagerOptions(save_interval_steps=1)
# ckpt_mgr = ocp.CheckpointManager(
#     cfg.paths.ckpt_dir,
#     item_names=("normalizer_params", "params", "env_steps"),
#     options=options,
# )
# episode_length = (env_args.clip_length - 50 - env_args.ref_len) * env._steps_for_cur_frame
# print(f"episode_length {episode_length}")
# train_fn = functools.partial(
#             ppo.train,
#             num_envs=cfg.train["num_envs"],
#             num_timesteps=0,
#             num_evals=int(cfg.train["num_timesteps"] / cfg.train["eval_every"]),
#             num_resets_per_eval=cfg.train['num_resets_per_eval'],
#             reward_scaling=cfg.train['reward_scaling'],
#             episode_length=episode_length,
#             normalize_observations=True,
#             action_repeat=cfg.train['action_repeat'],
#             clipping_epsilon=cfg.train["clipping_epsilon"],
#             unroll_length=cfg.train['unroll_length'],
#             num_minibatches=cfg.train["num_minibatches"],
#             num_updates_per_batch=cfg.train["num_updates_per_batch"],
#             discounting=cfg.train['discounting'],
#             learning_rate=cfg.train["learning_rate"],
#             kl_weight=cfg.train["kl_weight"],
#             entropy_cost=cfg.train['entropy_cost'],
#             batch_size=cfg.train["batch_size"],
#             seed=cfg.train['seed'],
#             network_factory=functools.partial(
#                 network_type,
#                 encoder_hidden_layer_sizes=cfg.train['encoder_hidden_layer_sizes'],
#                 decoder_hidden_layer_sizes=cfg.train['decoder_hidden_layer_sizes'],
#                 value_hidden_layer_sizes=cfg.train['value_hidden_layer_sizes'],
#             ),
#             checkpoint_network_factory=functools.partial(
#                     custom_ppo_networks.make_intention_ppo_networks,
#                     intention_latent_size=60,
#                     encoder_hidden_layer_sizes=cfg.train.ckpt_net['encoder_hidden_layer_sizes'],
#                     decoder_hidden_layer_sizes=cfg.train.ckpt_net['decoder_hidden_layer_sizes'],
#                     value_hidden_layer_sizes=cfg.train.ckpt_net['value_hidden_layer_sizes'],
#                 ),
#             checkpoint_path=max_ckpt,
#             freeze_mask_fn=None if (cfg.train['freeze_decoder'] == False) else masks.create_decoder_mask,
#             continue_training=True,
#             custom_wrap=True,  # custom wrappers to handle infos
#         )

# make_inference_fn, params, _= train_fn(
#     environment=env,            
#     policy_params_fn=policy_params_fn,
#     checkpoint_manager=ckpt_mgr,)
# params2 = (params[0],params[1].policy)
# policy_params = (params[0],params[1].policy)
# # Env_steps = params[2]
# jit_inference_fn = jax.jit(make_inference_fn(policy_params, deterministic=True))

# Simulate Rollout

In [None]:
# Versions of reset/step vmapped over a leading batch (environment) axis.
vmap_reset = jax.vmap(jit_reset)
vmap_step = jax.vmap(jit_step)
# NOTE(review): the rollout cell below calls jit_inference_fn directly on the
# batched observations; vmap_inference is not used there.
vmap_inference = jax.vmap(jit_inference_fn, in_axes=(0,None))

In [None]:
# Placeholder per-clip container: one empty dict per reference clip.
rollout_data = {'clip{:02d}'.format(n): {} for n in range(env._n_clips)}
# Roll out the policy on `nclips` environments in parallel, forcing
# environment i onto reference clip i.
nclips = 500
rng = jax.random.PRNGKey(0)
# JAX keys must never be reused. Derive independent reset/action keys from
# `rng` once, then split the *reset* key into one key per environment.
# (Previously `rng` itself was split a second time, reusing a consumed key.)
reset_rng, act_rng = jax.random.split(rng)
reset_rng = jax.random.split(reset_rng, nclips)
state = vmap_reset(reset_rng)
# NOTE(review): clip_idx is overwritten after reset. Anything reset derived from
# its own clip choice (e.g. the first observation) presumably still reflects
# that choice — confirm. This also assumes nclips <= env._n_clips.
state.info['clip_idx'] = jp.arange(nclips)
print(state.info['clip_idx'])
rollout = [state]
rollout_len = env._clip_length
ctrl_all,extras_all = [],[]
for i in tqdm(range(rollout_len)):
    # Advance the action key and give this step its own fresh subkey, so no
    # key is both consumed and split again.
    act_rng, step_rng = jax.random.split(act_rng)
    obs = state.obs
    ctrl, extras = jit_inference_fn(obs, step_rng)
    state = vmap_step(state, ctrl)
    ctrl_all.append(ctrl.copy())
    extras_all.append(extras)
    rollout.append(state)
# Per-step physics states, each batched over the nclips environments.
rollout2 = [s.pipeline_state for s in rollout]


In [None]:
# Stack per-step batched arrays into (time, clip, ...) and transpose to
# (clip, time, ...).
policy_data = {}
policy_data['sensordata'] = jp.stack([state.sensordata for state in rollout2]).transpose(1,0,2)
policy_data['qposes'] = jp.stack([state.qpos for state in rollout2]).transpose(1,0,2)
policy_data['qvels'] = jp.stack([state.qvel for state in rollout2]).transpose(1,0,2)
policy_data['ctrl'] = jp.stack(ctrl_all).transpose(1,0,2)

In [None]:
# Render one clip with the policy pose and the reference pose side by side.
# The pair model's qpos is the concatenation of both poses (see the loop).
render_path = Path(cfg.dataset.rendering_mjcf).parent / 'fruitfly_force_pair.xml'   

clip_idx = 125
# Policy qpos trajectory for this clip (time, nq), given the (clip, time, nq)
# layout of policy_data.
qposes_rollout = policy_data['qposes'][clip_idx]
# NOTE(review): this equals ref_data[clip_idx] from the config cell.
qposes_ref = np.concatenate([reference_clip.position,reference_clip.quaternion,reference_clip.joints], axis=-1)[clip_idx]

# NOTE(review): repeats_per_frame is not used in this cell.
repeats_per_frame = 1 
##### Create model from xml file #####
spec = mujoco.MjSpec()
spec = spec.from_file(render_path.as_posix())
mj_model = spec.compile()
##### Set solver options ##### 
mj_model.opt.solver = {
    "cg": mujoco.mjtSolver.mjSOL_CG,
    "newton": mujoco.mjtSolver.mjSOL_NEWTON,
}["cg"]
mj_model.opt.iterations = cfg.dataset.env_args.iterations
mj_model.opt.ls_iterations = cfg.dataset.env_args.ls_iterations
mj_model.opt.timestep = env.sys.mj_model.opt.timestep
##### Create mj data #####
mj_data = mujoco.MjData(mj_model)
###### change site colors to show policy and target #####
# Sites whose names contain "-0" are colored red (presumably the first of the
# two bodies in the pair model — confirm).
site_names = [
    mj_model.site(i).name
    for i in range(mj_model.nsite)
    if "-0" in mj_model.site(i).name
]
site_id = [
    mj_model.site(i).id
    for i in range(mj_model.nsite)
    if "-0" in mj_model.site(i).name
]
# NOTE: loop variable `id` shadows the builtin.
for id in site_id:
    mj_model.site(id).rgba = [1, 0, 0, 1]

##### Add scene options for rendering #####
scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]

# NOTE(review): MuJoCo picks its GL backend when it is first imported, so
# setting MUJOCO_GL here likely has no effect.
os.environ["MUJOCO_GL"] = "egl"
mujoco.mj_kinematics(mj_model, mj_data)

frames = []
# render while stepping using mujoco
# video_path = f"{model_path}/{num_steps}.mp4"
# with imageio.get_writer(video_path, fps=50) as video:

# Pose-only playback: each frame sets qpos and runs mj_forward (no stepping).
# zip stops at the shorter of the two trajectories.
with mujoco.Renderer(mj_model, height=512, width=512) as renderer:
    for qpos1, qpos2 in zip(qposes_rollout, qposes_ref):
        mj_data.qpos = np.append(qpos1, qpos2)
        mujoco.mj_forward(mj_model, mj_data)
        renderer.update_scene(mj_data, camera=1, scene_option=scene_option)
        pixels = renderer.render()
        # video.append_data(pixels)
        frames.append(pixels)
        
media.show_video(frames, fps=25)


In [None]:
# Render one environment of the batched rollout with contact points and
# contact forces visible.
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True
clip_idx = 125
# Select environment `clip_idx` from every per-step pipeline state.
# jax.tree_map is deprecated and removed in recent JAX releases;
# jax.tree_util.tree_map is the stable equivalent.
rollout_data = jax.tree_util.tree_map(lambda x: x[clip_idx], rollout2)
pixels = rollout_env.render(rollout_data, camera='track3', width=480, height=480, scene_option=scene_option)
media.show_video(pixels,fps=50)

In [None]:
# stacked_data = {}
# for key2 in rollout_data['clip00'].keys():
#     stacked_data[key2] = np.stack([rollout_data['clip{:02d}'.format(n)][key2] for n in range(env._n_clips)])
# ioh5.save('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/amp_control_clip.h5', amp_data)
# Save the policy rollout arrays (sensordata, qposes, qvels, ctrl) to HDF5.
ioh5.save(f'/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/Torque_control_nclips{nclips}_{run_id}.h5', policy_data)


In [None]:

def find_onsets_offsets(signal, threshold=0.5):
    """Detect rising and falling edges of a (roughly) square signal.

    The trace starts in the "low" state. A sample strictly above ``threshold``
    while low is an onset. A sample strictly below ``threshold`` while high is
    an offset. Samples exactly at the threshold never change state, and a
    segment still high at the end of the signal has no offset.

    Args:
        signal: iterable of numbers (e.g. a 1-D numpy array).
        threshold: switching level (default 0.5).

    Returns:
        (onsets, offsets): lists of sample indices.
    """
    rising, falling = [], []
    is_high = False
    for idx, sample in enumerate(signal):
        if is_high:
            if sample < threshold:
                falling.append(idx)
                is_high = False
        elif sample > threshold:
            rising.append(idx)
            is_high = True
    return rising, falling


In [None]:
# NOTE(review): this loads "Torque_control_multiclip_{run_id}.h5", not the
# "Torque_control_nclips{nclips}_{run_id}.h5" file written by the save cell
# above — confirm which dataset is intended.
policy_data = ioh5.load(f'/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/Torque_control_multiclip_{run_id}.h5')


In [None]:

# amp_force = amp_data['sensordata'].reshape(1,amp_data['sensordata'].shape[1],-1,3)
# Reshape sensor data to (clip, time, sensor, xyz). The clip count now comes
# from the data instead of the previously hard-coded 17, so files with any
# number of clips work; the result is identical when there are 17 clips.
# np.array copies, so the in-place zeroing below leaves policy_data untouched.
ctrl_force = np.array(policy_data['sensordata'])
ctrl_force = ctrl_force.reshape(ctrl_force.shape[0], ctrl_force.shape[1], -1, 3)
# Sensors from index 6 onward. This is a view of ctrl_force, so the zeroing
# below also applies to it.
joint_force = ctrl_force[:,:,6:,:]

# Zero out numerically negligible forces.
ctrl_force[np.abs(ctrl_force)<1e-4]=0
ctrl_force_clip = ctrl_force.copy()
ctrl_force_clip.shape

In [None]:
# Smooth each force trace along time (axis=1, 'same' keeps the length) and
# threshold it into a binary contact mask.
# NOTE(review): the kernel is np.ones(N), not np.ones(N)/N as in the
# commented-out amp version and the later plotting cell. That makes it a moving
# *sum*, and 0.01 is thresholded on that sum — confirm this is intended.
N = 10
# amp_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='same'), axis=1, arr=amp_force_clip)
# amp_binary = np.zeros_like(amp_smooth)
# amp_binary[np.abs(amp_smooth)>0.01] = 1
ctrl_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N), mode='same'), axis=1, arr=ctrl_force_clip)
ctrl_binary = np.zeros_like(ctrl_smooth)
ctrl_binary[np.abs(ctrl_smooth)>0.01] = 1

In [None]:

def calculate_peakforces(sdata_binary, sdata, min_duration=10):
    """Summarize force magnitude over detected contact segments.

    For every (clip, end effector, axis) trace, contact segments come from
    ``find_onsets_offsets`` on the binary mask. Each onset is paired with the
    next offset; a trailing onset with no offset is ignored. Segments shorter
    than ``min_duration`` samples are skipped.

    Args:
        sdata_binary: 0/1 contact mask indexed [clip, time, end_eff, axis].
            Its segment indices are used to slice ``sdata``.
        sdata: force traces indexed [clip, time, end_eff, axis].
        min_duration: minimum ``offset - onset`` (samples) for a segment to
            count. Defaults to 10, the previously hard-coded value.

    Returns:
        peak_force_all: (n_clips, n_end_eff, n_axes) NaN-mean over segments of
            each segment's mean |force|. NaN where no segment qualifies.
        peak_force_all_std: same shape, NaN-std over segments of that quantity.
        sdata_square: array like ``sdata``. It is zero outside counted segments
            and sign(mean force) * mean |force| inside each one.
    """
    n_clips, _, n_end_eff, n_dims = sdata_binary.shape
    sdata_square = np.zeros_like(sdata)
    peak_force_all = np.full((n_clips, n_end_eff, n_dims), np.nan)
    peak_force_all_std = np.full((n_clips, n_end_eff, n_dims), np.nan)
    for clip_idx in range(n_clips):
        for end_eff_idx in range(n_end_eff):
            for dim in range(n_dims):
                onsets, offsets = find_onsets_offsets(
                    sdata_binary[clip_idx, :, end_eff_idx, dim])
                peak_force = []
                for on, off in zip(onsets, offsets):
                    if off - on < min_duration:
                        continue
                    segment = sdata[clip_idx, on:off, end_eff_idx, dim]
                    magnitude = np.nanmean(np.abs(segment))
                    # NaN-aware sign, consistent with the NaN-aware magnitude.
                    sign = np.sign(np.nanmean(segment))
                    sdata_square[clip_idx, on:off, end_eff_idx, dim] = sign * magnitude
                    peak_force.append(magnitude)
                # Without qualifying segments the result stays NaN. This gives
                # the same value as before but skips the "empty slice"
                # RuntimeWarnings.
                if peak_force:
                    peak_force_all[clip_idx, end_eff_idx, dim] = np.nanmean(peak_force)
                    peak_force_all_std[clip_idx, end_eff_idx, dim] = np.nanstd(peak_force)
    return peak_force_all, peak_force_all_std, sdata_square

In [None]:
peak_force_ctrl,peak_force_ctrl_std, ctrl_square = calculate_peakforces(ctrl_binary,ctrl_force_clip)
# Per-(end effector, axis) summary across clips: mean and spread of the
# per-clip peak forces. The spread is computed over the same per-clip values as
# the mean; previously it was the std of the per-clip *stds*, which does not
# measure variability of peak force across clips.
ctrl_mean_peak_force = np.nanmean(peak_force_ctrl,axis=0)
ctrl_std_peak_force = np.nanstd(peak_force_ctrl,axis=0)
# peak_force_amp,peak_force_amp_std, amp_square = calculate_peakforces(amp_binary,amp_force_clip)
# amp_mean_peak_force = np.nanmean(peak_force_amp,axis=0)
# amp_std_peak_force = np.nanstd(peak_force_amp_std,axis=(0))
# amp_std_peak_force.shape,ctrl_std_peak_force.shape

In [None]:
# Display names for the six claw end effectors in policy order: legs T1, T2, T3,
# each left before right.
end_eff_ctrl = [
    f'claw T{segment}{side}'
    for segment in (1, 2, 3)
    for side in ('left', 'right')
]
# Same names in the spaced format of the amp data, which excludes 'claw T1 left'.
end_eff_amp = [
    f'claw T{segment} {side}'
    for segment in (1, 2, 3)
    for side in ('left', 'right')
][1:]

In [None]:
# Create a masked array using the binary mask
masked_data_ctrl = np.ma.masked_array(np.abs(ctrl_force_clip), mask=~(ctrl_binary.astype(bool))) 
# masked_data_amp = np.ma.masked_array(np.abs(amp_force_clip), mask=~(amp_binary.astype(bool))) 

# Calculate the average along the specified axis
ctr_mean_avg_force = np.ma.mean(masked_data_ctrl, axis=(0,1))
ctr_mean_std_force = np.ma.std(masked_data_ctrl, axis=(0,1))
# amp_mean_avg_force = np.ma.mean(masked_data_amp, axis=(0,1))
# amp_mean_std_force = np.ma.std(masked_data_amp, axis=(0,1))


In [None]:
# Visual check of contact detection for one trace: raw force, smoothed force,
# binary mask and the square-wave summary.
clip_idx = 10
end_eff_idx = 1
dim = 1
N = 10
# amp_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=1, arr=amp_force_clip)
# amp_binary = np.zeros_like(amp_smooth)
# amp_binary[np.abs(amp_smooth)>0.01] = 1
# NOTE(review): this redefines the globals ctrl_smooth/ctrl_binary with a
# normalized 'full'-mode filter (N-1 samples longer than the data). Cells that
# later reuse ctrl_binary will see this version.
ctrl_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=1, arr=ctrl_force_clip)
ctrl_binary = np.zeros_like(ctrl_smooth)
ctrl_binary[np.abs(ctrl_smooth)>0.01] = 1
plt.plot(ctrl_force_clip[clip_idx,:,end_eff_idx,dim])
plt.plot(ctrl_smooth[clip_idx,:,end_eff_idx,dim])
plt.plot(ctrl_binary[clip_idx,:,end_eff_idx,dim])
plt.plot(ctrl_square[clip_idx,:,end_eff_idx,dim])

In [None]:
# Plot 10 * |smoothed force| (axis index 1) for end effector end_eff_idx+1 of
# one clip, on a normalized 0-1 time axis.
clip_idx = 0
# ground_force_clip = ground_force[clip_idx,:, 0].T
fontsize = 12
# NOTE(review): only axs[0] is drawn on; `i` is unused.
fig, axs = plt.subplots(2, 1, figsize=(5, 4))
times = np.linspace(0,1,ctrl_smooth.shape[1])
end_eff_idx = 0
i = 9
ax = axs[0]
ax.plot(times,np.abs(10*ctrl_smooth[clip_idx,:,end_eff_idx+1,1]), label=f'Ctrl')

ax.legend(frameon=False,fontsize=fontsize,loc='upper right',bbox_to_anchor=(.9,1.3),labelcolor='linecolor',handlelength=0,handleheight=0,ncols=2,columnspacing=.1)
# ax.set_yticks([0,20,40])
# ax.set_xticks([0,.5,1])
ax.set_xlabel('time (s)')

In [None]:

# Plot 10 * |square-wave contact force| (axis index 0) for one end effector.
# Uses clip_idx and fontsize from earlier cells.
fig, axs = plt.subplots(2, 1, figsize=(2.5, 4))
ax = axs[0]
end_eff_idx = 1
ax.plot(np.linspace(0,1,ctrl_square.shape[1]),10*np.abs(ctrl_square[clip_idx,:,end_eff_idx,0]), label='ctrl', c='#5fc1ffff')
ax.set_xticks([0,0.5,1])
ax.set_yticks([0,0.5,1])
ax.set_xlabel('time (s)', fontsize=fontsize)
# '\\mu' is escaped: a bare '\m' is an invalid escape sequence (SyntaxWarning
# on Python 3.12+). '\n' is an intentional line break in the label.
ax.set_ylabel('R1 average \n force ($\\mu$N)', fontsize=fontsize)


In [None]:
# Indices used below to select joint entries from qpos-shaped arrays.
joint_idxs = env._joint_idxs
# qposes = np.array([rollout2[t].qpos for t in range(len(rollout2))])
# forces = np.array([rollout2[t].qfrc_constraint + rollout2[t].qfrc_smooth for t in range(len(rollout2))])

In [None]:
# Moving average (window N, 'full' mode) of |joint force| along time (axis 0)
# for clip `clip_idx`. This result is overwritten by the next plotting cell.
joint_forces_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=0, arr=np.abs(joint_force[clip_idx]))
joint_forces_smooth.shape

In [None]:
# Compare one reference joint angle with the smoothed joint-force magnitude
# over a dt-step window of one clip, and save the figure as SVG.
clip_idx = 1
end_eff_idx = 1
t = 10
dt = 500
wind=15
# [int, slice, index-array] indexing puts the joint axis first, so .T gives
# (time, joint).
joints_policy = policy_data['qposes'][clip_idx, t:t+dt, joint_idxs].T
joints_ref = ref_data[clip_idx,t:t+dt, joint_idxs].T
ground_forces =ctrl_square[clip_idx,t:t+dt]
joint_forces = joint_force[clip_idx,t:t+dt]
# Sum |force| over xyz, moving-average over time ('same' keeps the length),
# then cut out the window.
joint_forces_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(wind)/wind, mode='same'), axis=0,arr=np.sum(np.abs(joint_force[clip_idx]),axis=-1))[t:t+dt]
fontsize = 12
times = np.linspace(0,1,joints_policy.shape[0])

fig, axs = plt.subplots(2, 1, figsize=(3, 3))
i = 6
ax = axs[0]
# NOTE(review): joints_ref is already restricted to joint_idxs. The extra 7+i
# offset assumes joint_idxs still includes the 7 root qpos entries — confirm.
ax.plot(times,np.rad2deg(joints_ref[:, 7+i]),c='k', label='Data', lw=2)
ax.set_ylabel('R1 body-coxa \n angle (deg)')
ax.set_yticks([0,10,20])
ax.set_xlabel('time (s)')
ax = axs[1]
ax.plot(times,10*np.abs(joint_forces_smooth[:,0]),c='r', label='forces',lw=1.5)
ax.tick_params(axis='y')
# '\\mu' is escaped: a bare '\m' is an invalid escape sequence (SyntaxWarning
# on Python 3.12+). '\n' is an intentional line break.
ax.set_ylabel('R1 body-coxa \n force ($\\mu$N)')
ax.set_yticks([0,.25,.5])
ax.set_xlabel('time (s)')
plt.tight_layout()
fig.savefig(fig_dir/f'joint_force_{clip_idx}.svg',bbox_inches='tight',dpi=300)

In [None]:
# Overlay one joint angle of the reference data and the policy rollout for a
# single clip, dropping the first 10 steps.
clip_idx = 1
# `qposes` is not defined in any visible cell (only in a commented-out line),
# so this cell raised NameError. Use the rollout qpos in policy_data, indexed
# exactly like the reference below. Advanced indexing puts the joint axis
# first; .T restores (time, joint).
joints_policy = policy_data['qposes'][clip_idx, 10:, joint_idxs].T
joints_ref = ref_data[clip_idx,10:, joint_idxs].T
fontsize = 12
fig, axs = plt.subplots(1, 1, figsize=(3, 2))
i = 9
ax = axs
ax.plot(np.linspace(0,2,joints_ref.shape[0]),np.rad2deg(joints_ref[:, i]),c='k', label='Data', lw=2)
ax.plot(np.linspace(0,2,joints_policy.shape[0]),np.rad2deg(joints_policy[:, i]),c='#5fc1ffff', label='Policy',lw=1)
ax.legend(frameon=False,fontsize=fontsize,loc='upper left',bbox_to_anchor=(.01,1.3),labelcolor='linecolor',handlelength=0,handleheight=0,ncols=2,columnspacing=.1)
ax.set_yticks([0,20,40])
ax.set_xticks([0,1,2])
ax.set_xlabel('time (s)')
ax.set_ylabel('R1 femur-tibia \n angle (deg)')

In [None]:
# Render the rollout next to reference clip 0 in the pair model (policy qpos
# and reference qpos are concatenated per frame).
# NOTE(review): if `rollout` comes from the batched rollout cell, each qpos is
# (nclips, nq), so np.append below produces the wrong length for the pair
# model. This cell appears to assume a single-env rollout — confirm.
render_path = Path(cfg.dataset.rendering_mjcf).parent / 'fruitfly_force_pair.xml'   
qposes_rollout = np.array([state.pipeline_state.qpos for state in rollout])
# NOTE(review): ref_traj is not used in this cell.
ref_traj = env._get_reference_clip(rollout[0].info)

repeats_per_frame = 1 
spec = mujoco.MjSpec()
spec = spec.from_file(render_path.as_posix())
mj_model = spec.compile()
# position = reference_clip.position
# position = position.at[:,:,2].set(position[:,:,2] - 0.005)
# reference_clip = reference_clip.replace(position=position)
# Reference qpos trajectory of clip 0.
qposes_ref = np.concatenate([reference_clip.position,reference_clip.quaternion,reference_clip.joints], axis=-1)[0]

mj_model.opt.solver = {
    "cg": mujoco.mjtSolver.mjSOL_CG,
    "newton": mujoco.mjtSolver.mjSOL_NEWTON,
}["cg"]
mj_model.opt.iterations = cfg.dataset.env_args.iterations
mj_model.opt.ls_iterations = cfg.dataset.env_args.ls_iterations
mj_model.opt.timestep = env.sys.mj_model.opt.timestep

mj_data = mujoco.MjData(mj_model)

# Color sites whose names contain "-0" red.
site_names = [
    mj_model.site(i).name
    for i in range(mj_model.nsite)
    if "-0" in mj_model.site(i).name
]
site_id = [
    mj_model.site(i).id
    for i in range(mj_model.nsite)
    if "-0" in mj_model.site(i).name
]
for id in site_id:
    mj_model.site(id).rgba = [1, 0, 0, 1]

scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True

# save rendering and log to wandb
# (no wandb logging actually happens in this cell.)
# NOTE(review): MuJoCo selects its GL backend at import time, so switching
# MUJOCO_GL to osmesa here likely has no effect.
os.environ["MUJOCO_GL"] = "osmesa"
mujoco.mj_kinematics(mj_model, mj_data)
# renderer = mujoco.Renderer(mj_model, height=512, width=512)

frames = []
# render while stepping using mujoco
# video_path = f"{model_path}/{num_steps}.mp4"
# with imageio.get_writer(video_path, fps=50) as video:

# Pose-only playback (mj_forward, no stepping); zip stops at the shorter trajectory.
with mujoco.Renderer(mj_model, height=512, width=512) as renderer:
    for qpos1, qpos2 in zip(qposes_rollout, qposes_ref):
        mj_data.qpos = np.append(qpos1, qpos2)
        mujoco.mj_forward(mj_model, mj_data)
        renderer.update_scene(mj_data, camera=2, scene_option=scene_option)
        pixels = renderer.render()
        # video.append_data(pixels)
        frames.append(pixels)
        
media.show_video(frames, fps=25)


In [None]:
# Replay the policy controls in plain MuJoCo on the viz-only fly model. Each
# frame first resets qpos to the policy rollout, then steps physics until the
# sim time reaches i/fps. Contact sensor data is recorded at every physics step.
# NOTE(review): ctrl_all/qposes_rollout carry a clip axis when they come from
# the batched rollout, but data.ctrl/data.qpos expect (nu,)/(nq,) — confirm
# which rollout this cell is meant to use.
ctrl_all = jp.array(ctrl_all)
# NOTE(review): this overwrites the global `model_path` (the checkpoint dir set
# earlier) with an MJCF file path.
model_path = "/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_viz_only.xml"
spec = mujoco.MjSpec()
spec = spec.from_file(model_path)
thorax = spec.find_body("thorax")
# NOTE(review): first_joint is unused while the delete() below is commented out.
first_joint = thorax.first_joint()
# first_joint.delete()
root = spec.compile()
root.opt.solver = {
    "cg": mujoco.mjtSolver.mjSOL_CG,
    "newton": mujoco.mjtSolver.mjSOL_NEWTON,
}[cfg.dataset.env_args.solver.lower()]
root.opt.iterations = env_args.iterations
root.opt.ls_iterations = env_args.ls_iterations
root.opt.timestep = env_args.physics_timestep
# 0 == mjJAC_DENSE.
root.opt.jacobian = 0
data = mujoco.MjData(root)
data.qpos = qposes_rollout[0]
mujoco.mj_forward(root, data)
n_frames = 150# ctrl_all.shape[0]
height = 1024
width = 1024
frames = []
fps = 1/env.dt
times = []
sensordata = []
scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 0, 0, 0, 0, 0]
scene_option.geomgroup[:] = [1, 1, 0, 0, 0, 0]

scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True

# NOTE(review): this rebinds the global `rollout`. Also, rollout.append(data)
# stores the same MjData object every frame, so all entries alias the final
# state.
qpos_all,rollout,ncon_all = [],[],[]
with mujoco.Renderer(root, height, width) as renderer:
    for i in range(n_frames):
        data.ctrl = ctrl_all[i].copy()
        data.qpos = qposes_rollout[i].copy()
        while data.time < i/fps:
            mujoco.mj_step(root, data)
            sensordata.append(data.sensordata.copy())
        times.append(data.time)
        renderer.update_scene(data,camera='track1',scene_option=scene_option)
        frame = renderer.render()
        frames.append(frame)
        qpos_all.append(data.qpos.copy())
        ncon_all.append(data.ncon)
        rollout.append(data)

# NOTE(review): playback fps is hard-coded to 50 rather than using `fps`.
media.show_video(frames, fps=50)


In [None]:
# Save frame 134 of the replay as a still image (note: different root from fig_dir).
media.write_image('/data/users/eabe/biomech_model/Flybody_imitation/RL_Flybody_imitation/debug/figures/walking.png',frames[134])

In [None]:
# Claw end-effector identifiers: legs T1, T2, T3, each left before right.
end_eff = [
    f'claw_T{segment}_{side}'
    for segment in (1, 2, 3)
    for side in ('left', 'right')
]

In [None]:
# Unit conversion: 1 mg·cm/s^2 = 1.0e-8 N (1 mg = 1e-6 kg, 1 cm = 1e-2 m)

In [None]:
# Convert replay sensor data to newtons (assuming sensor units of mg·cm/s^2,
# where 1 mg·cm/s^2 = 1e-8 N) and reshape to (time, end_eff, xyz). This assumes
# 18 sensor values = 6 claws x 3 axes.
N = 100
sdata = 1e-8*(np.stack(sensordata).reshape(-1,6,3)) # Time x end_eff x xyz, x=forward
# Moving average over time with an N-sample window ('full' mode adds N-1 samples).
sdata = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=0, arr=sdata)

# One subplot per end effector with its three force components overlaid.
fig, axs = plt.subplots(3, 2, figsize=(10, 10), sharey=True)
axs = axs.flatten()
for n in range(len(end_eff)):
    ax = axs[n]
    ax.plot(sdata[:,n,0])
    ax.plot(sdata[:,n,1])
    ax.plot(sdata[:,n,2])
# plt.plot(sdata[:,:,2])
# plt.plot(sdata[:,:,2])

In [None]:
# Stack each step's policy latent 'z' along a new leading time axis.
latents = jp.stack([extras_all[n]['z'] for n in range(len(extras_all))])

In [None]:
# Histogram of each latent dimension. Any leading (time[, env]) axes are
# flattened first. With the batched rollout, `latents` is (time, env, dim), and
# latents[:, n] indexed the env axis instead of latent dimension n.
latents_2d = np.asarray(latents).reshape(-1, latents.shape[-1])
fig, axs = plt.subplots(6,10, figsize=(15,10),sharex=True,sharey=True)
axs = axs.flatten()
# The 6x10 grid shows at most 60 latent dimensions.
for ax, column in zip(axs, latents_2d.T):
    ax.hist(column,bins=50)
plt.tight_layout()


In [None]:
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA

# Fit a full PCA on the latent codes and plot the cumulative explained variance.
pca = PCA()
embeded = pca.fit_transform(latents)
cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
plt.plot(cumulative_variance)

In [None]:
# Table of the first four latent dimensions (one row per sample).
first_four = latents[:, :4]
df = pd.DataFrame(first_four)
df

In [None]:
# Pairwise scatter/histogram grid of the columns of df.
pair_grid = sns.pairplot(df)
pair_grid

In [None]:
# Sensor readings of every element of rollout2, regrouped as (time, 6, 3).
sensordata = jp.stack([pipeline.sensordata for pipeline in rollout2])
force_out = sensordata.reshape(-1, 6, 3)


In [None]:
# Component 2 of sensor group 0 over time.
group0_comp2 = force_out[:, 0, 2]
plt.plot(group0_comp2)

# Test reward functions

In [None]:
# Compare, for index 0, the env's stored reference, the notebook's reference_clip and
# the current state's qpos. Jupyter only displays a cell's last expression, so the
# quaternion and position comparisons were silently dropped; print each explicitly.
print(env._reference_clip.quaternion[0], reference_clip.quaternion[0], state.pipeline_state.qpos[3:7])
print(env._reference_clip.position[0], reference_clip.position[0], state.pipeline_state.qpos[:3])
print(env._reference_clip.joints[0], reference_clip.joints[0], state.pipeline_state.qpos[7:])


In [None]:
# Zero-control rollout: reset with a fixed PRNG key, take 100 jitted env steps with an
# all-zeros action, and keep every pipeline state (including the initial one).
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
rollout = [state.pipeline_state]

for i in range(100):
    # Alternative controls / state overrides that have been tried here:
    # ctrl = ctrl_mpc[i]
    # ctrl = 5*jax.random.normal(rng, shape=(env.sys.nu,), dtype=jp.float32)
    # dpos =jp.array([jp.sum(reference_clip.lin_vel_y[:i]*env.dt),0,0])
    # qpos_t = jp.concatenate((reference_clip.position[i]+dpos,reference_clip.quaternion[i], reference_clip.joints[i]))
    # qpos_t = jp.concatenate((reference_clip.position[i],reference_clip.quaternion[i], reference_clip.joints[i]))
    # qvel_t = jp.concatenate((reference_clip.velocity[i],reference_clip.angular_velocity[i], reference_clip.joints_velocity[i]))
    # pipeline = state.pipeline_state.replace(qpos=qpos_t, qvel=qvel_t)
    # pipeline = state.pipeline_state.replace(qpos=reference_clip.joints[i], qvel=reference_clip.joints_velocity[i])
    # state=state.replace(pipeline_state=pipeline)
    ctrl = jp.zeros((env.sys.nu,))
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

# Render with site groups 0-4 visible and contact points/forces drawn.
scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]
for vis_flag in (mujoco.mjtVisFlag.mjVIS_CONTACTPOINT, mujoco.mjtVisFlag.mjVIS_CONTACTFORCE):
    scene_option.flags[vis_flag] = True

pixels = env.render(rollout, width=256, height=256, camera=1, scene_option=scene_option)


In [None]:
# Play back the frames in `pixels` at 50 fps. The scene options below are only
# consumed by the (commented-out) high-resolution re-render.
scene_option = mujoco.MjvOption()
scene_option.geomgroup[:] = [1, 1, 0, 0, 0, 0]
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]
for vis_flag in (mujoco.mjtVisFlag.mjVIS_CONTACTPOINT, mujoco.mjtVisFlag.mjVIS_CONTACTFORCE):
    scene_option.flags[vis_flag] = True
# pixels = env.render(rollout, width=1000, height=1000, camera=1,scene_option=scene_option)
media.show_video(pixels, fps=50)

In [None]:
# Offline re-computation of the imitation-tracking distances and rewards for clip
# `clip_idx`, frame by frame. `rollout` must be a list of brax States (each element
# exposes .pipeline_state). Weights/scalings are defined here, not read from env.
_pos_reward_weight = 1
_joint_reward_weight = 1
_angvel_reward_weight = 1
_bodypos_reward_weight = 1
_endeff_reward_weight = 1
_quat_reward_weight = 1
_pos_scaling = 400.0
_joint_scaling = 0.25
_angvel_scaling = 0.5
_bodypos_scaling = 4.0
_endeff_scaling = 0.05
_quat_scaling = 500.0
clip_idx = 0
rewards = {
    key: []
    for key in (
        'pos', 'joint', 'angvel', 'bodypos', 'endeff',
        'pos_reward', 'joint_reward', 'angvel_reward', 'bodypos_reward', 'endeff_reward',
        'quat_distance', 'quat_reward',
    )
}

# Only evaluate frames present in both the rollout and the reference clip. The old
# hard-coded range(1000) raised IndexError for shorter rollouts, and JAX clamps
# out-of-range reference indices, so a short clip silently repeated its last frame.
n_eval_frames = min(len(rollout), reference_clip.position.shape[1])
for cur_frame in range(n_eval_frames):
    data = rollout[cur_frame].pipeline_state

    # Root position tracking (qpos[:3])
    pos_track = reference_clip.position[clip_idx, cur_frame]
    pos_distance = jp.sum((data.qpos[:3] - pos_track) ** 2)
    pos_reward = _pos_reward_weight * jp.exp(-_pos_scaling * pos_distance)

    # Root orientation tracking (qpos[3:7])
    quat_track = reference_clip.quaternion[clip_idx, cur_frame]
    quat_distance = jp.sum(_bounded_quat_dist(data.qpos[3:7], quat_track) ** 2)
    quat_reward = _quat_reward_weight * jp.exp(-_quat_scaling * quat_distance)

    # Joint angle tracking (qpos[7:])
    joint_track = reference_clip.joints[clip_idx, cur_frame]
    joint_distance = jp.sum((data.qpos[7:] - joint_track) ** 2)
    joint_reward = _joint_reward_weight * jp.exp(-_joint_scaling * joint_distance)

    # Root angular velocity tracking (qvel[3:6])
    angvel_track = reference_clip.angular_velocity[clip_idx, cur_frame]
    angvel_distance = jp.sum((data.qvel[3:6] - angvel_track) ** 2)
    angvel_reward = _angvel_reward_weight * jp.exp(-_angvel_scaling * angvel_distance)

    # Body position tracking over env._body_idxs
    bodypos_track = reference_clip.body_positions[clip_idx, cur_frame]
    bodypos_distance = jp.sum((data.xpos[env._body_idxs] - bodypos_track[env._body_idxs]).flatten() ** 2)
    bodypos_reward = _bodypos_reward_weight * jp.exp(-_bodypos_scaling * bodypos_distance)

    # End-effector tracking.
    # NOTE(review): env._endeff_idxs indexes body positions (xpos) here, while the
    # velocity-reward cell indexes site positions (site_xpos) with it — confirm which
    # index space the env defines.
    endeff_track = reference_clip.body_positions[clip_idx, cur_frame]
    endeff_distance = jp.sum((data.xpos[env._endeff_idxs] - endeff_track[env._endeff_idxs]).flatten() ** 2)
    endeff_reward = _endeff_reward_weight * jp.exp(-_endeff_scaling * endeff_distance)

    step_values = {
        'pos': pos_distance,
        'joint': joint_distance,
        'angvel': angvel_distance,
        'bodypos': bodypos_distance,
        'endeff': endeff_distance,
        'quat_distance': quat_distance,
        'pos_reward': pos_reward,
        'joint_reward': joint_reward,
        'angvel_reward': angvel_reward,
        'bodypos_reward': bodypos_reward,
        'endeff_reward': endeff_reward,
        'quat_reward': quat_reward,
    }
    for key, value in step_values.items():
        rewards[key].append(value)
    


In [None]:

# Top panel: raw tracking distances. Bottom panel: the exponentiated rewards.
# The pos and angvel traces are intentionally left out of both panels.
distance_keys = ['joint', 'bodypos', 'endeff', 'quat_distance']
reward_keys = ['joint_reward', 'bodypos_reward', 'endeff_reward', 'quat_reward']

fig, axs = plt.subplots(2, 1, figsize=(10, 10))
for ax, keys in zip(axs, (distance_keys, reward_keys)):
    for key in keys:
        ax.plot(rewards[key], label=key)
    ax.legend()
ax.set_ylim(-.1, 1.1)



In [None]:
from jax.flatten_util import ravel_pytree
from brax import math as brax_math

# Frame-by-frame re-computation of the velocity-command reward terms and the
# termination conditions for `rollout` (a list of brax States), so that individual
# terms can be plotted and compared.
rewards = {
    'summed_pos_distance_all': [],
    'target_pos_all': [],
    'tracking_lin_vel_all': [],
    'tracking_ang_vel_all': [],
    'ang_vel_xy_all': [],
    'lin_vel_z_all': [],
    'orientation_all': [],
    'torques_all': [],
    'action_rate_all': [],
    'stand_still_all': [],
    'termination_all': [],
    'foot_slip_all': [],
}
# Reward hyperparameters kept for reference and for the plotting cell.
# NOTE: the loop below reads the env's own attributes (env._pos_scaling,
# env._too_far_dist, ...), not these module-level values.
too_far_dist = 0.01
bad_pose_dist = 10.0
pos_reward_weight = 2.0
tracking_lin_vel_weight = 0.0
tracking_ang_vel_weight = 0.0
lin_vel_z_weight = -5e-5
ang_vel_xy_weight = -1e-3
orientation_weight = 1.0
torques_weight = -0.0002
action_rate_weight = -0.01
stand_still_weight = -0.5
foot_slip_weight = -0.1
termination_weight = 1.0
pos_scaling = 200.0
linvel_scaling = 400.0
angvel_scaling = 4.0
ang_vel_xy_scaling = -1.0
lin_vel_z_scaling = -1.0
orientation_scaling = 5.0
torques_scaling = -1.0
action_rate_scaling = -1.0
stand_still_scaling = -1.0
foot_slip_scaling = -1.0
init_pos = env.sys.qpos0[:3]
target_trace = []
joint_distances = []
local_vel_all = []
term_conds = {
    'done': [],
    'nan': [],
    'fall': [],
    'too_far': [],
    'bad_pose': [],
    }
for cur_frame in range(len(rollout)):
    # Score every frame with the State recorded at that frame. Previously `info`,
    # `obs` and `last_contact` were read from the global `state` (whatever an earlier
    # cell left behind), so all frames shared a single step counter, command,
    # prev_ctrl and contact history.
    cur_state = rollout[cur_frame]
    data = cur_state.pipeline_state

    info = cur_state.info.copy()
    joint_angles = data.q[7:]
    joint_vel = data.qd[7:]
    x, xd = data.x, data.xd
    action = info['prev_ctrl']
    # Healthy while the thorax height lies inside env._healthy_z_range.
    min_z, max_z = env._healthy_z_range
    is_healthy = jp.where(data.xpos[env._thorax_idx][2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.xpos[env._thorax_idx][2] > max_z, 0.0, is_healthy)
    fall = 1.0 - is_healthy
    # up = jp.array([0.0, 0.0, 1.0])
    # done = jp.int32(jp.dot(brax_math.rotate(up, x.rot[env._thorax_idx - 1]), up) < 0)

    reference_obs, proprioceptive_obs = env._get_obs(data, info, cur_state.obs)
    obs = jp.concatenate([reference_obs, proprioceptive_obs])

    ##### Tracking position from velocity #####
    # Target: initial position advanced by the commanded xy velocity for info['step'] steps.
    target_pos = env._init_pos + info['step']*jp.concatenate((info['command'][:2],jp.zeros(1)))*env.dt
    pos_distance = (data.qpos[:3] - target_pos)
    pos_reward = env._pos_reward_weight * jp.exp(
        -env._pos_scaling * jp.sum(pos_distance**2)
    )
    # z error is down-weighted by 0.2 for the too-far termination check.
    summed_pos_distance = jp.sum((pos_distance * jp.array([1.0, 1.0, 0.2])) ** 2)
    too_far = jp.where(summed_pos_distance > env._too_far_dist, 1.0, 0.0)
    ##### Tracking joint positions #####
    joint_distance = jp.sum((data.qpos[7:] - env._default_pose) ** 2)
    bad_pose = jp.where(joint_distance > env._bad_pose_dist, 1.0, 0.0)
    info['joint_distance'] = joint_distance
    joint_distances.append(joint_distance)
    target_trace.append(data.qpos[:3])
    # ##### Tracking quaternion flat orientation ######
    # quat_distance = jp.sum(
    #     _bounded_quat_dist(data.qpos[3:7], jp.array()) ** 2
    # )
    # quat_reward = env._quat_reward_weight * jp.exp(
    #     -env._quat_scaling * quat_distance
    # )

    # Any NaN anywhere in the pipeline state counts as a termination condition.
    flattened_vals, _ = ravel_pytree(data)
    num_nans = jp.sum(jp.isnan(flattened_vals))
    nan = jp.where(num_nans > 0, 1.0, 0.0)
    done = jp.max(jp.array([nan, fall, too_far, bad_pose]))

    # foot contact data based on z-position
    foot_pos = data.site_xpos[env._endeff_idxs]  # pytype: disable=attribute-error
    foot_contact_z = foot_pos[:, 2] - env._foot_radius
    contact = foot_contact_z < 1e-3  # a mm or less off the floor
    contact_filt_mm = contact | cur_state.info["last_contact"]
    contact_filt_cm = (foot_contact_z < 3e-2) | cur_state.info["last_contact"]

    # Tracking of linear velocity commands (xy axes), in the frame of x.rot[0]
    local_vel = brax_math.rotate(xd.vel[0], brax_math.quat_inv(x.rot[0]))
    lin_vel_error = jp.sum(jp.square(info["command"][:2] - local_vel[:2]))
    lin_vel_reward = jp.exp(-env._linvel_scaling * lin_vel_error)
    info["bodypos_distance"] = lin_vel_error
    tracking_lin_vel = env._tracking_lin_vel_weight * lin_vel_reward
    local_vel_all.append(local_vel)
    # tracking_lin_vel = env._tracking_lin_vel_weight * env._reward_tracking_lin_vel(info['command'], x, xd)
    tracking_ang_vel = (
        env._tracking_ang_vel_weight
        * env._reward_tracking_ang_vel(info["command"], x, xd)
    )
    # Penalty terms are clipped to [scaling, 0].
    ang_vel_xy = jp.clip(env._ang_vel_xy_weight * env._reward_ang_vel_xy(xd),env._ang_vel_xy_scaling,0)
    lin_vel_z = jp.clip(env._lin_vel_z_weight * env._reward_lin_vel_z(xd),env._lin_vel_z_scaling,0)
    orientation = env._orientation_weight * jp.exp(-env._orientation_scaling*env._reward_orientation(x))
    torques = jp.clip(env._torques_weight * env._reward_torques(data.qfrc_actuator),env._torques_scaling,0)
    # NOTE(review): `action` is info["prev_ctrl"], so both arguments are the same array
    # and this term likely evaluates to zero; the previous frame's prev_ctrl may be what
    # is intended — confirm against env._reward_action_rate.
    action_rate = jp.clip(env._action_rate_weight * env._reward_action_rate(
        action, info["prev_ctrl"]
    ), env._action_rate_scaling,0)
    stand_still = jp.clip(env._stand_still_weight * env._reward_stand_still(
        info["command"],
        joint_angles,
    ),env._stand_still_scaling,0)
    foot_slip = jp.clip(env._foot_slip_weight*env._reward_foot_slip(data, contact_filt_cm),env._foot_slip_scaling,0)
    termination = env._termination_weight * env._reward_termination(
        done, info["step"]
    )

    reward = (
        pos_reward
        + tracking_lin_vel
        + tracking_ang_vel
        + ang_vel_xy
        + lin_vel_z
        + orientation
        + torques
        + action_rate
        + stand_still
        + foot_slip
        + termination
    )

    term_conds['done'].append(done)
    term_conds['nan'].append(nan)
    term_conds['fall'].append(fall)
    term_conds['too_far'].append(too_far)
    term_conds['bad_pose'].append(bad_pose)
    rewards["summed_pos_distance_all"].append(summed_pos_distance)
    rewards['target_pos_all'].append(pos_reward)
    rewards['tracking_lin_vel_all'].append(tracking_lin_vel)
    rewards['tracking_ang_vel_all'].append(tracking_ang_vel)
    rewards['ang_vel_xy_all'].append(ang_vel_xy)
    rewards['lin_vel_z_all'].append(lin_vel_z)
    rewards['orientation_all'].append(orientation)
    rewards['torques_all'].append(torques)
    rewards['action_rate_all'].append(action_rate)
    rewards['stand_still_all'].append(stand_still)
    rewards['foot_slip_all'].append(foot_slip)
    rewards['termination_all'].append(termination)

joint_distances = jp.stack(joint_distances)
target_trace = jp.stack(target_trace)
local_vel_all = jp.stack(local_vel_all)
for key,val in rewards.items():
    rewards[key] = np.stack(val)
for key,val in term_conds.items():
    term_conds[key] = np.stack(val)

In [None]:
# List the names of the collected reward traces.
print(list(rewards))

In [None]:
# Per-frame traces pulled directly from the States recorded in `rollout`
# (the duplicated vels_direct computation has been removed).
termination_trace = jp.array([s.metrics['termination'] for s in rollout])
tracking_error = jp.array([s.info['bodypos_distance'] for s in rollout])
commands = jp.array([s.info['command'] for s in rollout])
vels_direct = jp.array([s.pipeline_state.xd.vel[0] for s in rollout])  # first entry of xd.vel
pos_all = jp.array([s.pipeline_state.xpos[1] for s in rollout])  # xpos of body index 1

In [None]:
# Inspect the values left by the most recent run of the reward loop: xy of the
# velocity rotated into the frame of x.rot[0], xy of the unrotated xd.vel[0], and
# the command.
local_vel[:2],xd.vel[0,:2], info['command']

In [None]:
# Scratch diagnostic plot. It currently shows `pos_all` (xpos of body index 1 per
# frame); uncomment any of the alternatives below to inspect other traces.
fig, axs = plt.subplots(figsize=(10, 5))  # a single Axes
ax = axs
# --- every reward term / termination condition ---
# for key, val in rewards.items():
#     ax.plot(val, label=key)
# for key, val in term_conds.items():
#     ax.plot(val, label=key)
# ax.legend(loc='upper right')
# --- individual reward terms ---
# ax.plot(rewards['target_pos_all'])
# ax.plot(rewards['summed_pos_distance_all'])
# ax.plot(rewards['tracking_lin_vel_all'])
# ax.plot(jp.clip(rewards['lin_vel_z_all'], lin_vel_z_scaling, 0))
# ax.plot(jp.clip(rewards['ang_vel_xy_all'], ang_vel_xy_scaling, 0))
# ax.plot(rewards['orientation_all'])
# ax.plot(jp.clip(rewards['torques_all'], torques_scaling))
# ax.plot(jp.clip(rewards['action_rate_all'], action_rate_scaling))
# ax.plot(jp.clip(rewards['stand_still_all'], stand_still_scaling))
# ax.plot(np.clip(rewards['foot_slip_all'], foot_slip_scaling, 0))
# ax.plot(rewards['termination_all'], 'r', lw=2)
# ax.plot(jp.clip(sum(rewards.values()), 0.0, 10000.0), 'k', lw=2)
# --- trajectories, velocities and tracking error ---
# ax.plot(target_trace)
# ax.plot(jp.exp(-400*tracking_error[:100]))
# ax.plot(np.diff(commands[:,0]))
# ax.plot(jp.exp(-20*tracking_error))
# ax.plot(joint_distances)
# ax.plot(local_vel_all[:,:2])
# ax.axhline(y = jp.mean(vels_direct[:,:2],axis=0)[1])
# ax.plot(jp.mean(vels_direct[:,:2],axis=0))
ax.plot(pos_all)

In [None]:
# First component of the last frame's rotated velocity left by the reward loop.
local_vel[0]

In [None]:
# Shape of the per-frame tracking_error trace collected from the rollout infos.
tracking_error.shape

In [None]:
# The "force_pair" model file lives in the same directory as cfg.dataset.rendering_mjcf.
render_path = Path(cfg.dataset.rendering_mjcf).parent.joinpath('fruitfly_force_pair.xml')

In [None]:
# NOTE: ref_traj is assigned in the rendering cell below
# (env._get_reference_clip(...)); this cell only works after that one has run.
ref_traj

In [None]:
# Render the policy rollout and the reference trajectory together in the paired fly
# model: each frame's qpos is the rollout pose followed by the reference pose.
qposes_rollout = np.array([s.pipeline_state.qpos for s in rollout])

ref_traj = env._get_reference_clip(rollout[0].info)  # NOTE: not used below

repeats_per_frame = 1  # env._steps_for_cur_frame #int(1/(env._mocap_hz*env.sys.mj_model.opt.timestep))
# MjSpec.from_file returns a freshly loaded spec, so no empty MjSpec() is needed first.
spec = mujoco.MjSpec.from_file(render_path.as_posix())
# (Deleting the thorax free joints when env._free_jnt is False, and repeating the
# reference `repeats_per_frame` times per frame, are currently disabled.)

# Reference pose for clip 0 with its position z-component lowered by 0.005 for display.
# This is computed locally: the previous code wrote the shifted positions back into the
# global `reference_clip`, so every re-run of this cell lowered the reference again and
# the offset leaked into the other cells that read reference_clip.position.
ref_position = reference_clip.position[0].at[:, 2].add(-0.005)
qposes_ref = np.concatenate(
    [ref_position, reference_clip.quaternion[0], reference_clip.joints[0]], axis=-1
)

mj_model = spec.compile()

mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = cfg.dataset.env_args.iterations
mj_model.opt.ls_iterations = cfg.dataset.env_args.ls_iterations
mj_model.opt.timestep = env.sys.mj_model.opt.timestep

mj_data = mujoco.MjData(mj_model)

# Sites whose name contains "-0" are coloured red.
site_id = [i for i in range(mj_model.nsite) if "-0" in mj_model.site(i).name]
site_names = [mj_model.site(i).name for i in site_id]
for sid in site_id:
    mj_model.site(sid).rgba = [1, 0, 0, 1]

scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True

# NOTE(review): MUJOCO_GL is consulted when mujoco is imported, so setting it here
# presumably does not switch the backend of this already-running session — confirm.
os.environ["MUJOCO_GL"] = "osmesa"
mujoco.mj_kinematics(mj_model, mj_data)

frames = []
# zip stops at the shorter of the rollout and the reference trajectory.
with mujoco.Renderer(mj_model, height=512, width=512) as renderer:
    for qpos1, qpos2 in zip(qposes_rollout, qposes_ref):
        mj_data.qpos = np.append(qpos1, qpos2)
        mujoco.mj_forward(mj_model, mj_data)
        renderer.update_scene(mj_data, camera=1, scene_option=scene_option)
        pixels = renderer.render()
        frames.append(pixels)

In [None]:
# Shapes of the last rollout / reference qpos pair from the render loop; their
# concatenation is what gets written to mj_data.qpos.
qpos1.shape,qpos2.shape

In [None]:
# Display a single rendered frame (index 10).
frame_10 = frames[10]
media.show_image(frame_10)

In [None]:
# Play back all rendered frames.
playback_fps = 50
media.show_video(frames, fps=playback_fps)


In [None]:
# Load the MPC solution and transpose its ctrl/qpos arrays (indexed per frame
# downstream, e.g. ctrl_mpc[i]).
mpc_path = '/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/fruitfly_qpos2_002ms.h5'
mpc1_data = ioh5.load(mpc_path)
ctrl_mpc, qpos_mpc = (jp.array(mpc1_data[field].T) for field in ('ctrl', 'qpos'))

In [None]:
# Kinematic playback of reference_clip joint angles in the "fast" fly model with the
# thorax body's first joint deleted; qpos/qvel/xpos/xquat are recorded every frame.
model_path = ("/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_force_fast.xml")
# MjSpec.from_file returns the loaded spec (the render cell above relies on this by
# assigning its result). The previous `spec = mujoco.MjSpec(); spec.from_file(...)`
# discarded that return value and kept editing the empty spec.
spec = mujoco.MjSpec.from_file(model_path)
thorax = spec.find_body("thorax")
first_joint = thorax.first_joint()
first_joint.delete()
mj_model = spec.compile()

mj_model.opt.timestep = 0.0002
data = mujoco.MjData(mj_model)
mujoco.mj_forward(mj_model, data)
n_frames = 100
height = 240
width = 320
frames = []
fps = 1/.002  # only used by the commented-out real-time stepping below
times = []
# sensordata = []
scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]

scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True
qpos_rollout = []
qvel_rollout = []
xpos_rollout = []
xquat_rollout = []
with mujoco.Renderer(mj_model, height, width) as renderer:
    for i in range(n_frames):
        # data.ctrl = ctrl_mpc[i]
        # data.qpos = np.concatenate([np.zeros(4),qpos_mpc[i]])
        # data.qpos = reference_clip.joints[i]
        # NOTE(review): other cells index reference_clip.joints as [clip, frame]; if it
        # is (clips, time, joints) here, joints[i] selects clip i rather than frame i.
        # The 4 leading zeros presumably fill the DoFs that precede the leg joints —
        # confirm against mj_model.nq.
        data.qpos = np.concatenate([np.zeros(4),reference_clip.joints[i]])
        # mujoco.mj_forward(mj_model, data)
        # while data.time < i/fps:
        # Advance one physics step (mj_model.opt.timestep) from the imposed pose.
        mujoco.mj_step(mj_model,data)
        qpos_rollout.append(data.qpos.copy())
        qvel_rollout.append(data.qvel.copy())
        xpos_rollout.append(data.xpos.copy())
        xquat_rollout.append(data.xquat.copy())
        # times.append(data.time)
        #     sensordata.append(data.sensor('force_tarsus_T1_left').data.copy())
        renderer.update_scene(data,camera=1,scene_option=scene_option)
        frame = renderer.render()
        frames.append(frame)

# # grab a trajectory
# for i in range(times.shape[0]):
#     data.ctrl = ctrl[i]
#     state = mujoco.mj_step(mj_model,data)
#     rollout.append(state)

media.show_video(frames, fps=50)


In [None]:
# Convert the recorded per-frame lists into stacked jax arrays.
qpos_rollout, qvel_rollout, xpos_rollout, xquat_rollout = (
    jp.array(trace)
    for trace in (qpos_rollout, qvel_rollout, xpos_rollout, xquat_rollout)
)

In [None]:
# Env step size; 1/env.dt is the frame rate used when replaying rollouts.
env.dt