# Imports and rendering setup

In [None]:
#@title Check if MuJoCo installation was successful
# Verify that a GPU is reachable and MuJoCo works, then configure rendering and
# XLA through environment variables. Set these before mujoco is imported and
# before JAX initialises its backend, otherwise they are not picked up.

import distutils.util  # NOTE(review): unused in this cell; `distutils` was removed in Python 3.12.
import os
import subprocess

_GPU_ERROR_MESSAGE = (
    'Cannot communicate with GPU. '
    'Make sure you are using a GPU Colab runtime. '
    'Go to the Runtime menu and select Choose runtime type.')
try:
  _nvidia_smi_returncode = subprocess.run('nvidia-smi').returncode
except FileNotFoundError as err:
  # `nvidia-smi` is not on PATH at all: report it the same way as a failed call.
  raise RuntimeError(_GPU_ERROR_MESSAGE) from err
if _nvidia_smi_returncode:
  raise RuntimeError(_GPU_ERROR_MESSAGE)

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
# %env MUJOCO_GL=egl
os.environ['MUJOCO_GL'] = 'egl'
os.environ['PYOPENGL_PLATFORM'] = 'egl'
try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  # Raise the explanatory error and chain the original failure as its cause,
  # so the traceback shows both.
  raise RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".') from e

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs.
# The commented-out flags are optional extras kept for experimentation.
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_triton_softmax_fusion=true "
    "--xla_gpu_triton_gemm_any=True "
    # "--xla_gpu_enable_async_collectives=true "
    # "--xla_gpu_enable_latency_hiding_scheduler=true "
    # "--xla_gpu_enable_highest_priority_async_stream=true "
)
os.environ['CUDA_VISIBLE_DEVICES'] = '1' # Use GPU 1

In [None]:
%load_ext autoreload
%autoreload 2
import os

# Fraction of GPU memory JAX may preallocate; set before `import jax` below.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Use GPU 1
import functools
import jax
# jax.config.update("jax_enable_x64", True)

# Report how many GPU devices JAX can see.
n_gpus = jax.device_count(backend="gpu")
print(f"Using {n_gpus} GPUs")
from typing import Dict
from brax import envs
import mujoco
import pickle
import warnings
import mediapy as media
import hydra
import jax.numpy as jp

from omegaconf import DictConfig, OmegaConf
from brax.training.agents.ppo import networks as ppo_networks
from custom_brax import custom_ppo as ppo
from custom_brax import custom_wrappers
from custom_brax import custom_ppo_networks
from orbax import checkpoint as ocp
from flax.training import orbax_utils
from preprocessing.mjx_preprocess import process_clip_to_train
from envs.Fly_Env_Brax import FlyTracking, FlyMultiClipTracking
# NOTE(review): wildcard import; presumably also provides `np`, `plt`, `Path`
# and `ioh5`, which later cells use without an explicit import — confirm.
from utils.utils import *
from utils.fly_logging import log_eval_rollout

# Silence DeprecationWarnings for the rest of the session.
warnings.filterwarnings("ignore", category=DeprecationWarning)
# jax.config.update("jax_enable_x64", True)

from hydra import initialize, compose
from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra

##### Plotting settings ######
# Thick axes/ticks, no top/right spines, TrueType fonts in PDFs
# (pdf.fonttype 42) and text kept as text in SVGs (svg.fonttype 'none').
# NOTE(review): 'font.serif' is likely unused since 'font.family' is sans-serif.
import matplotlib as mpl
mpl.rcParams.update({'font.size':          10,
                     'axes.linewidth':     2,
                     'xtick.major.size':   5,
                     'ytick.major.size':   5,
                     'xtick.major.width':  2,
                     'ytick.major.width':  2,
                     'axes.spines.right':  False,
                     'axes.spines.top':    False,
                     'pdf.fonttype':       42,
                     'xtick.labelsize':    10,
                     'ytick.labelsize':    10,
                     'figure.facecolor':   'white',
                     'pdf.use14corefonts': True,
                     'svg.fonttype':       'none',
                     'font.family':        'sans-serif',
                    #  'font.family':        'Arial',
                    #  'font.sans-serif':    'Arial',
                     'font.serif':         'Arial',
                    })

# Load configs

In [None]:
# List every saved run config under the checkpoint root, with its index.
# NOTE(review): `Path` is not imported explicitly in this notebook; presumably
# it comes from `from utils.utils import *` — confirm.
base_dir ='/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt'
run_cfg_list = sorted(list(Path(base_dir).rglob('run_config.yaml')))
for n, run_cfg in enumerate(run_cfg_list):
    print(n, run_cfg)


# Pick the run to analyse by its index in the list printed above.
cfg_num = 6
cfg = OmegaConf.load(run_cfg_list[cfg_num])
# Parse the run id from the config's grandparent directory name, which must
# look like '<key>=<int>' (e.g. 'run_id=21356039').
run_id = int(run_cfg_list[cfg_num].parent.parent.stem.split('=')[1])
print(cfg.dataset.dname, run_id)
# Output directory for saved figures.
fig_dir = Path('/data/users/eabe/biomech_model/Flybody/RL_Flybody/debug/figures')

In [15]:
# Compose the Hydra config for this run's dataset with the 'walle' path preset,
# then copy only its `paths` section onto the loaded run config.
dataset = cfg.dataset.dname
with initialize(version_base=None, config_path="configs"):
    cfg_temp=compose(config_name='config.yaml',overrides= [f"dataset={dataset}", f"train=train_{dataset}", "paths=walle", "version=ckpt", f'run_id={run_id}'],return_hydra_config=True,)
    # Register the composed config as the active Hydra config.
    HydraConfig.instance().set_config(cfg_temp)
    
cfg.paths = cfg_temp.paths

In [16]:
# Convert every configured path except 'user' to a Path and create the directory.
for k in cfg.paths.keys():
    if (k != 'user'):
        cfg.paths[k] = Path(cfg.paths[k])
        cfg.paths[k].mkdir(parents=True, exist_ok=True)
env_cfg = cfg.dataset
env_args = cfg.dataset.env_args

# Point base_dir at the sibling 'ckpt' directory.
cfg.paths.base_dir = cfg.paths.base_dir.parent / 'ckpt'
reference_path = cfg.paths.data_dir/ f"clips/all_clips_batch_interp.p"
# reference_path = cfg.paths.data_dir/ f"clips/{cfg.dataset['clip_idx']}"
# NOTE(review): creating the parent dir does not create the clip file; if it is
# missing, open() below still fails.
reference_path.parent.mkdir(parents=True, exist_ok=True)

with open(reference_path, "rb") as file:
    # Use pickle.load() to load the data from the file
    # NOTE(review): pickle can execute arbitrary code; only load trusted files.
    reference_clip = pickle.load(file)
# Reference qpos per frame: position, quaternion and joint values concatenated
# along the last axis.
ref_data = np.concatenate([reference_clip.position,reference_clip.quaternion,reference_clip.joints], axis=-1)

# Load env

In [None]:
from envs.Fly_Env_Brax import FlyTracking, FlyMultiClipTracking, FlyRunSim, _bounded_quat_dist
# dataset = 'multiclip'

# with initialize(version_base=None, config_path="configs"):
#     cfg=compose(config_name='config.yaml',overrides= [f"dataset=fly_{dataset}", f"train=train_fly_{dataset}", "paths=walle"],return_hydra_config=True,)
#     HydraConfig.instance().set_config(cfg)


# env_args = cfg.dataset.env_args
# Register the env classes under names that cfg.train.env_name can refer to.
envs.register_environment("fly_freejnt_clip", FlyTracking)
envs.register_environment("fly_freejnt_multiclip", FlyMultiClipTracking)
envs.register_environment("fly_run_policy", FlyRunSim)
# cfg.dataset.env_args.mjcf_path = '/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_viz_only.xml'
# cfg.dataset.env_args.mjcf_path = '/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_force_fast.xml'
print(cfg.train.env_name)
# Override the run config's solver iteration counts before building the env.
cfg.dataset.env_args.iterations = 12
cfg.dataset.env_args.ls_iterations = 12
env = envs.get_environment(
    cfg.train.env_name,
    reference_clip=reference_clip,
    **cfg.dataset.env_args,
)

In [18]:
# Wrap the env for rendering rollouts and JIT-compile reset/step.
rollout_env = custom_wrappers.RenderRolloutWrapperTracking(env)
    
# rollout_env = custom_wrappers.RenderRolloutWrapperTracking(env)
# define the jit reset/step functions
jit_reset = jax.jit(rollout_env.reset)
jit_step = jax.jit(rollout_env.step)
# First call compiles reset with a fixed key.
state = jit_reset(jax.random.PRNGKey(0))


# Test inference

In [19]:
from orbax import checkpoint as ocp
from flax.training import orbax_utils
import optax

In [None]:
model_path = Path(cfg.paths.ckpt_dir/f"{run_id}")
##### Get all the checkpoint files #####
# Checkpoint sub-directories are named by training step (e.g. '900000',
# '1000000'). Sort numerically: a plain lexicographic sort would place
# '900000' after '1000000' and pick the wrong "latest" checkpoint.
# Non-numeric directory names sort before all numeric ones.
def _ckpt_sort_key(path):
    name = path.name
    is_step = name.isdigit()
    return (is_step, int(name) if is_step else 0, name)

ckpt_files = sorted(
    (Path(f.path) for f in os.scandir(model_path) if f.is_dir()),
    key=_ckpt_sort_key,
)
if not ckpt_files:
    raise FileNotFoundError(f"No checkpoint directories found in {model_path}")
max_ckpt = ckpt_files[-1].as_posix()
env_args = cfg.dataset.env_args
print(max_ckpt)

In [None]:
def policy_params_fn(num_steps, make_policy, params, policy_params_fn_key, model_path=model_path):
  """Write ``params`` to an Orbax checkpoint directory named ``num_steps``.

  Existing checkpoints at the same location are overwritten (``force=True``).

  NOTE(review): ``make_policy``, ``policy_params_fn_key`` and ``model_path``
  are accepted but unused; the destination is hard-coded to a 'Test_path'
  directory under run 21356039 — confirm this is intended.
  """
  checkpointer = ocp.PyTreeCheckpointer()
  orbax_save_args = orbax_utils.save_args_from_target(params)
  test_root = Path('/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt/run_id=21356039/ckpt/Test_path')
  checkpointer.save(
      test_root / f'{num_steps}',
      params,
      force=True,
      save_args=orbax_save_args,
  )
  
  
# Episode length in env steps: (clip_length - 50 - ref_traj_length) frames,
# times the number of env steps per reference frame.
# NOTE(review): the meaning of the 50-frame margin is not visible here.
episode_length = (env_args.clip_length - 50 - env_cfg.ref_traj_length) * env._steps_for_cur_frame
print(f"episode_length {episode_length}")

# PPO trainer configured from the run's training config. With num_timesteps=0
# and restore_checkpoint_path=max_ckpt this is presumably only used to restore
# the checkpoint and build the inference function — confirm in custom_ppo.train.
train_fn = functools.partial(
            ppo.train,
            num_envs=cfg.train["num_envs"],
            num_timesteps=0,
            num_evals=1,
            num_resets_per_eval=cfg.train['num_resets_per_eval'],
            reward_scaling=cfg.train['reward_scaling'],
            episode_length=episode_length,
            normalize_observations=True,
            action_repeat=cfg.train['action_repeat'],
            clipping_epsilon=cfg.train["clipping_epsilon"],
            unroll_length=cfg.train['unroll_length'],
            num_minibatches=cfg.train["num_minibatches"],
            num_updates_per_batch=cfg.train["num_updates_per_batch"],
            discounting=cfg.train['discounting'],
            learning_rate=cfg.train["learning_rate"],
            kl_weight=cfg.train["kl_weight"],
            entropy_cost=cfg.train['entropy_cost'],
            batch_size=cfg.train["batch_size"],
            seed=cfg.train['seed'],
            network_factory=functools.partial(
                custom_ppo_networks.make_intention_ppo_networks,
                encoder_hidden_layer_sizes=cfg.train['encoder_hidden_layer_sizes'],
                decoder_hidden_layer_sizes=cfg.train['decoder_hidden_layer_sizes'],
                value_hidden_layer_sizes=cfg.train['value_hidden_layer_sizes'],
            ),
            restore_checkpoint_path=max_ckpt,
        )

make_inference_fn, params, _= train_fn(environment=env,)
# Keep params[0] (presumably the observation-normalizer state) and the policy
# parameters. `params2` is an identical tuple and is unused in visible cells.
params2 = (params[0],params[1].policy)
policy_params = (params[0],params[1].policy)
# Env_steps = params[2]
jit_inference_fn = jax.jit(make_inference_fn(policy_params, deterministic=True))

In [None]:
# Number of clips in the env and their length.
env._n_clips, env._clip_length

In [None]:
# Roll out the deterministic policy once per clip and store qpos, qvel and
# actions under rollout_data['clipNN'].
rollout_data = {'clip{:02d}'.format(n): {} for n in range(env._n_clips)}
for n in range(env._n_clips):
    # n = 2
    # reset_rng, act_rng = jax.random.split(policy_params_fn_key)
    rng = jax.random.PRNGKey(0)
    reset_rng, act_rng = jax.random.split(rng)
    state = jit_reset(reset_rng)
    # Override the clip index stored in the reset state so this rollout tracks
    # clip n (assumes the env reads info['clip_idx'] during step — confirm).
    state.info['clip_idx'] = n
    print(state.info['clip_idx'])

    rollout = [state]
    # rollout_len = env_args["clip_length"]*int(rollout_env._steps_for_cur_frame)
    rollout_len = env._clip_length
    ctrl_all = []
    for i in range(rollout_len):
        _, act_rng = jax.random.split(act_rng)
        obs = state.obs
        ctrl, extras = jit_inference_fn(obs, act_rng)
        state = jit_step(state, ctrl)
        ctrl_all.append(ctrl.copy())
        rollout.append(state)
    # `rollout` and `rollout2` are rebuilt every iteration, so after the loop
    # they hold only the last clip; later rendering cells use them.
    rollout2 = [state.pipeline_state for state in rollout]
    rollout_data['clip{:02d}'.format(n)]['qposes'] = jp.stack([state.qpos for state in rollout2])
    rollout_data['clip{:02d}'.format(n)]['qvels'] = jp.stack([state.qvel for state in rollout2])
    rollout_data['clip{:02d}'.format(n)]['ctrl'] = jp.stack(ctrl_all)
# rollout_data['clip{:02d}'.format(n)]['sensordata'] = jp.stack([state.sensordata for state in rollout2])


In [None]:
# Opaque geoms with contact points and contact forces drawn.
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True

# Render the last clip's rollout (rollout2 from the rollout loop), skipping the
# first 500 steps.
pixels = rollout_env.render(rollout2[500:], camera='track2', width=512, height=512, scene_option=scene_option)
media.show_video(pixels,fps=50)

In [None]:
# Display the run id being analysed.
run_id

In [38]:
# Stack each field across clips into a (n_clips, ...) array and save to HDF5.
stacked_data = {}
for key2 in rollout_data['clip00'].keys():
    stacked_data[key2] = np.stack([rollout_data['clip{:02d}'.format(n)][key2] for n in range(env._n_clips)])
# ioh5.save('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/amp_control_clip.h5', amp_data)
ioh5.save(f'/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/Torque_control_multiclip_{run_id}.h5', stacked_data)


In [41]:
# Reload the saved policy rollouts and the AMP data from disk.
policy_data = ioh5.load(f'/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/Torque_control_multiclip_{run_id}.h5')
amp_data = ioh5.load('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/amp_control_clip.h5')

# ref_data = ioh5.load('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/ref_multiclip_qpos.h5')

In [None]:
# Fields available in each loaded dataset.
amp_data.keys(), policy_data.keys()

In [None]:
# Plot per-axis sensor traces for each end effector of one AMP clip.
# With N = 1 the moving average below is a no-op (one-sample window).
N = 1
nclips = amp_data['sensordata'].shape[0]
# (clips, timesteps, flat sensors) -> (clips, timesteps, end effector, xyz).
sensordata = amp_data['sensordata'].reshape(nclips,amp_data['sensordata'].shape[1],-1,3)
# NOTE(review): the 1e-8 scale vs the 'force (N)' axis label below is not
# explained here — confirm the units.
sdata = 1e-8*sensordata
# sdata = 1e-8*(np.stack(sensordata).reshape(nclips,sensordata.shape[1],-1,3)) # Time x end_eff x xyz, x=forward
sdata = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=1, arr=sdata)
# Names for the end-effector axis; presumably in sensor order with T1_left
# absent from this data — confirm against the model's sensors.
end_eff = [
# 'claw_T1_left',
'claw_T1_right',
'claw_T2_left',
'claw_T2_right',
'claw_T3_left',
'claw_T3_right',
]

clip_idx = 2
# Plotted window: timesteps t .. t+dt (clipped at the end of the data).
t=500
dt=10000
# 3x2 grid; only the first 5 axes are used.
fig, axs = plt.subplots(3, 2, figsize=(10, 10), sharey=True)
axs = axs.flatten()
for n in range(len(end_eff)):
    ax = axs[n]
    ax.plot(sdata[clip_idx,t:t+dt,n,0], label='x')
    ax.plot(sdata[clip_idx,t:t+dt,n,1], label='y')
    ax.plot(sdata[clip_idx,t:t+dt,n,2], label='z')
    ax.set_title(end_eff[n])
    ax.set_xlabel('timesteps')
    ax.set_ylabel('force (N)')
axs[0].legend()
plt.tight_layout()
# fig.savefig(fig_dir/'claw_force.png', dpi=300)

## Data saving

In [25]:

def find_onsets_offsets(signal, threshold=0.5):
    """Locate rising and falling threshold crossings in a square-wave signal.

    The samples are scanned once with a two-state (low/high) machine that
    starts low. While low, a sample strictly above ``threshold`` records an
    onset. While high, a sample strictly below ``threshold`` records an
    offset. A sample equal to the threshold never changes state.

    Args:
        signal: Iterable of numeric samples (e.g. a 1-D numpy array).
        threshold: Crossing level (default: 0.5).

    Returns:
        ``(onsets, offsets)``: two lists of sample indices. If the signal
        ends while high, the last onset has no matching offset, so
        ``onsets`` can be one element longer than ``offsets``.
    """
    rising_edges, falling_edges = [], []
    is_high = False
    for idx, sample in enumerate(signal):
        if not is_high:
            if sample > threshold:
                rising_edges.append(idx)
                is_high = True
        elif sample < threshold:
            falling_edges.append(idx)
            is_high = False
    return rising_edges, falling_edges


In [None]:
# Fields available in the policy rollout data.
policy_data.keys()


In [None]:

# Reshape flat sensor readings to (clips, timesteps, end effector, xyz).
# The clip count comes from the data itself: reshaping (clips, T, F) with a
# hard-coded leading size keeps the time axis intact only when that size
# equals the actual number of clips.
amp_force = amp_data['sensordata'].reshape(*amp_data['sensordata'].shape[:2],-1,3)
ctrl_force = policy_data['sensordata'].reshape(*policy_data['sensordata'].shape[:2],-1,3)

# Copy so the in-place thresholding below does not modify the loaded data.
amp_force = np.array(amp_force)
ctrl_force = np.array(ctrl_force)

# Treat magnitudes below 1e-4 as zero.
amp_force[np.abs(amp_force)<1e-4]=0
amp_force_clip = amp_force.copy()
ctrl_force[np.abs(ctrl_force)<1e-4]=0
ctrl_force_clip = ctrl_force.copy()
amp_force_clip.shape, ctrl_force_clip.shape # nclips x timesteps x end_eff x xyz

In [None]:
# Moving-average smoothing along time (axis=1) with an N-sample box filter,
# then a 0/1 mask of samples whose smoothed magnitude exceeds 0.01.
# mode='same' keeps the smoothed arrays the same length as, and centred on,
# the *_force_clip arrays they are paired with in calculate_peakforces.
# mode='full' would make them N-1 samples longer and push every falling edge
# up to N-1 samples late relative to the force data being indexed.
N = 10
amp_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='same'), axis=1, arr=amp_force_clip)
amp_binary = np.zeros_like(amp_smooth)
amp_binary[np.abs(amp_smooth)>0.01] = 1
ctrl_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='same'), axis=1, arr=ctrl_force_clip)
ctrl_binary = np.zeros_like(ctrl_smooth)
ctrl_binary[np.abs(ctrl_smooth)>0.01] = 1

In [None]:

def calculate_peakforces(sdata_binary,sdata):
    """Summarise per-event peak forces from a binary activity mask.

    For every (clip, end effector, dim) channel, events are the high segments
    of ``sdata_binary`` found by ``find_onsets_offsets`` (default threshold
    0.5). Segments shorter than 10 samples are ignored. For each remaining
    segment ``[on, off)`` the peak is ``max(|sdata[on:off]|)``.

    Args:
        sdata_binary: Activity mask indexed as [clip, time, end_eff, dim].
        sdata: Force data indexed the same way; peaks are taken from it.

    Returns:
        peak_force_all: (clips, end_effs, dims) mean of the per-event peaks,
            NaN for channels without a qualifying event.
        peak_force_all_std: Same shape; std of the per-event peaks.
        sdata_square: Array shaped like ``sdata``. It is zero outside events;
            inside each event it equals the event's peak, signed by the sign
            of the event's mean force.
    """
    n_clips, _, n_end_eff, n_dims = sdata_binary.shape
    sdata_square = np.zeros_like(sdata)
    peak_force_all = np.full((n_clips, n_end_eff, n_dims), np.nan)
    peak_force_all_std = np.full((n_clips, n_end_eff, n_dims), np.nan)
    for clip_idx in range(n_clips):
        for end_eff_idx in range(n_end_eff):
            for dim in range(n_dims):
                onsets, offsets = find_onsets_offsets(sdata_binary[clip_idx,:,end_eff_idx,dim])
                peaks = []
                # zip drops a trailing onset that never went low again.
                for on, off in zip(onsets, offsets):
                    if off - on < 10:
                        # Too short to count as an event.
                        continue
                    segment = sdata[clip_idx,on:off,end_eff_idx,dim]
                    peak = np.max(np.abs(segment))
                    sdata_square[clip_idx,on:off,end_eff_idx,dim] = np.sign(np.mean(segment)) * peak
                    peaks.append(peak)
                # Channels without events keep NaN, rather than calling
                # nanmean/nanstd on an empty list (which warns and returns NaN).
                if peaks:
                    peak_force_all[clip_idx,end_eff_idx,dim] = np.nanmean(peaks)
                    peak_force_all_std[clip_idx,end_eff_idx,dim] = np.nanstd(peaks)
    return peak_force_all, peak_force_all_std, sdata_square

In [None]:
peak_force_amp,peak_force_amp_std, amp_square = calculate_peakforces(amp_binary,amp_force_clip)
peak_force_ctrl,peak_force_ctrl_std, ctrl_square = calculate_peakforces(ctrl_binary,ctrl_force_clip)
# Mean per-event peak force averaged over clips -> (end_effs, dims).
amp_mean_peak_force = np.nanmean(peak_force_amp,axis=0)
# Error bars: the per-clip (across-event) stds averaged over clips.
# np.nanstd over the std array would instead measure how much the stds differ
# between clips, which is identically 0 whenever only one clip is present.
amp_std_peak_force = np.nanmean(peak_force_amp_std,axis=0)
ctrl_mean_peak_force = np.nanmean(peak_force_ctrl,axis=0)
ctrl_std_peak_force = np.nanmean(peak_force_ctrl_std,axis=0)
amp_std_peak_force.shape,ctrl_std_peak_force.shape

In [None]:
# End-effector labels for the policy (6) and AMP (5, no T1_left) force arrays;
# presumably in the same order as their end-effector axes — confirm.
end_eff_ctrl = [
'claw_T1_left',
'claw_T1_right',
'claw_T2_left',
'claw_T2_right',
'claw_T3_left',
'claw_T3_right',
]
end_eff_amp = [
'claw_T1_right',
'claw_T2_left',
'claw_T2_right',
'claw_T3_left',
'claw_T3_right',
]

In [None]:

# NOTE(review): `end_eff_idx` does not appear to be assigned at top level
# before this cell (the next cell sets it), so a fresh top-to-bottom run may
# raise NameError here.
ctrl_mean_peak_force[end_eff_idx+1,:]

In [None]:
# Grouped bars per axis (x, y, z): mean peak force for the policy (red) and
# AMP (blue) per end effector, with std error bars. AMP lacks T1_left, so AMP
# index i is paired with policy index i+1 and axs[0] stays empty.
fontsize=12
end_eff_idx = 0
fig, axs = plt.subplots(3,2,figsize=(6,6))
axs = axs.flatten()
for end_eff_idx in range(len(end_eff_amp)):
    ax = axs[end_eff_idx+1]
    # NOTE(review): both series are scaled by 10 and labelled nN — confirm the unit conversion.
    ax.bar(x=np.array([1,2,3]), height = 10*ctrl_mean_peak_force[end_eff_idx+1,:],width=0.25, yerr=10*ctrl_std_peak_force[end_eff_idx+1,:], color='r')
    ax.bar(x=np.array([1.25,2.25,3.25]), height = 10*amp_mean_peak_force[end_eff_idx,:],width=0.25, yerr=10*amp_std_peak_force[end_eff_idx,:], color='b')
    ax.set_xticks([1,2,3])
    ax.set_yticks([0,10,20,30])
    ax.set_xticklabels(['x','y','z'])
    ax.set_ylabel(f'{end_eff_amp[end_eff_idx]} \n peak force (nN)')
# ax.set_title(f'Peak force for {end_eff[end_eff_idx+1]}')
# Legend goes on the last subplot only.
ax.legend(['Ctrl','Amp'],frameon=False,fontsize=fontsize,loc='upper right',bbox_to_anchor=(.6,1.0),labelcolor='linecolor',handlelength=0,handleheight=0,ncols=1,columnspacing=.1)
fig.tight_layout()
# fig.savefig(fig_dir/f'peak_force_comparison.svg', dpi=300)

In [None]:
# Debug plot for one channel: raw force, smoothed force, binary mask and the
# "square" peak trace (ctrl_square from calculate_peakforces).
# NOTE: this overwrites the global ctrl_smooth / ctrl_binary. With mode='full'
# they are N-1 samples longer than ctrl_force_clip, and falling edges lag the
# raw trace by up to N-1 samples.
clip_idx = 10
end_eff_idx = 3
dim = 1
N = 10
# amp_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=1, arr=amp_force_clip)
# amp_binary = np.zeros_like(amp_smooth)
# amp_binary[np.abs(amp_smooth)>0.01] = 1
ctrl_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=1, arr=ctrl_force_clip)
ctrl_binary = np.zeros_like(ctrl_smooth)
ctrl_binary[np.abs(ctrl_smooth)>0.01] = 1
plt.plot(ctrl_force_clip[clip_idx,:,end_eff_idx,dim])
plt.plot(ctrl_smooth[clip_idx,:,end_eff_idx,dim])
plt.plot(ctrl_binary[clip_idx,:,end_eff_idx,dim])
plt.plot(ctrl_square[clip_idx,:,end_eff_idx,dim])

In [33]:
# ref_data = {'clip{:02d}'.format(n): {} for n in range(env._n_clips)}
# for n in range(env._n_clips):
#     rollout[0].info['clip_idx'] = n
#     ref_traj = env._get_reference_clip(rollout[0].info)

#     ref_data['clip{:02d}'.format(n)]['qpos'] = np.hstack([ref_traj.position, ref_traj.quaternion, ref_traj.joints])
# ioh5.save('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/ref_multiclip_qpos.h5',ref_data)

In [None]:
# Render the last clip's policy rollout together with its reference trajectory
# in the rendering model.
qposes_rollout = np.array([state.pipeline_state.qpos for state in rollout])

ref_traj = env._get_reference_clip(rollout[0].info)
print(f"clip_id:{rollout[0].info}")
# Repeat each reference frame _steps_for_cur_frame times so it lines up with
# the per-step rollout.
qposes_ref = np.repeat(
    np.hstack([ref_traj.position, ref_traj.quaternion, ref_traj.joints]),
    env._steps_for_cur_frame,
    axis=0,
)
spec = mujoco.MjSpec()
spec = spec.from_file(cfg.dataset.rendering_mjcf)
mj_model = spec.compile()
print(cfg.dataset.rendering_mjcf)
# Always selects CG: the dict is indexed with the constant "cg".
mj_model.opt.solver = {
    "cg": mujoco.mjtSolver.mjSOL_CG,
    "newton": mujoco.mjtSolver.mjSOL_NEWTON,
}["cg"]
mj_model.opt.iterations = cfg.dataset.env_args.iterations
mj_model.opt.ls_iterations = cfg.dataset.env_args.ls_iterations
mj_model.opt.timestep = env.sys.mj_model.opt.timestep

mj_data = mujoco.MjData(mj_model)
# Colour red every site whose name contains "-0".
site_id = [
    mj_model.site(i).id
    for i in range(mj_model.nsite)
    if "-0" in mj_model.site(i).name
]
# NOTE: the loop variable `id` shadows the builtin of the same name.
for id in site_id:
    mj_model.site(id).rgba = [1, 0, 0, 1]

scene_option = mujoco.MjvOption()
# Show site groups 0-4, hide group 5.
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True

# Render frames (nothing is logged to wandb in this cell).
# NOTE(review): MuJoCo reads MUJOCO_GL when `mujoco` is first imported; it was
# imported earlier, so this assignment likely does not affect the Renderer below.
os.environ["MUJOCO_GL"] = "osmesa"
mujoco.mj_kinematics(mj_model, mj_data)
# renderer = mujoco.Renderer(mj_model, height=512, width=512)

frames = []
# render while stepping using mujoco
# Each frame sets qpos to [rollout qpos, reference qpos]; assumes the rendering
# model holds both bodies with their qpos stacked in this order — confirm.
with mujoco.Renderer(mj_model, height=512, width=512) as renderer:
    for qpos1, qpos2 in zip(qposes_rollout[500:], qposes_ref[500:]):
        mj_data.qpos = np.append(qpos1, qpos2)
        mujoco.mj_forward(mj_model, mj_data)
        renderer.update_scene(mj_data, camera=1, scene_option=scene_option)
        pixels = renderer.render()
        frames.append(pixels)

In [None]:
# Play the rollout-vs-reference frames rendered above.
media.show_video(frames,fps=50)

In [None]:
# Replay the saved policy rollouts in a standalone MuJoCo model (rendering
# commented out) and collect sensordata for every clip into sensor_data_all.
# NOTE: this rebinds `model_path`, which earlier held the checkpoint directory.
# ctrl_all = jp.array(ctrl_all)
model_path = "/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_force_fastviz.xml"
# model_path = "/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_force_fast_amp.xml"
spec = mujoco.MjSpec()
spec = spec.from_file(model_path)
thorax = spec.find_body("thorax")
first_joint = thorax.first_joint()
# first_joint.delete()
root = spec.compile()
# Solver named in cfg.dataset.env_args, with 4 extra (line-search) iterations.
root.opt.solver = {
    "cg": mujoco.mjtSolver.mjSOL_CG,
    "newton": mujoco.mjtSolver.mjSOL_NEWTON,
}[cfg.dataset.env_args.solver.lower()]
root.opt.iterations = env_args.iterations +4
root.opt.ls_iterations = env_args.ls_iterations +4
root.opt.timestep = env_args.physics_timestep
# 0 == mujoco.mjtJacobian.mjJAC_DENSE
root.opt.jacobian = 0
data = mujoco.MjData(root)
# data.qpos = qposes_rollout[0]
height = 512
width = 512
# One saved ctrl/qpos row per env step; presumably env.dt is the env step duration.
fps = 1/env.dt
times = []
scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]

scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True
sensor_data_all = []
frames_all = []
for clip_idx in range(env._n_clips):
    print(clip_idx)
    ctrl_all = policy_data['ctrl'][clip_idx,:,:].copy()
    qposes = policy_data['qposes'][clip_idx,:,:].copy()
    n_frames = ctrl_all.shape[0]
    # Fresh simulation state per clip, starting from the clip's first logged pose.
    data = mujoco.MjData(root)
    data.qpos = qposes[0]
    mujoco.mj_forward(root, data)
    qpos_all,rollout,ncon_all = [],[],[]
    sensordata = []
    frames = []
    # with mujoco.Renderer(root, height, width) as renderer:
    for i in range(n_frames):
        # Apply the logged action, snap qpos to the logged pose, then step until
        # the sim clock reaches this frame's time i/fps.
        # NOTE(review): qvel is carried over from the replay, not set from the
        # logged rollout — confirm this is intended.
        data.ctrl = ctrl_all[i].copy()
        data.qpos = qposes[i].copy()
        while data.time < i/fps:
            mujoco.mj_step(root, data)
        sensordata.append(data.sensordata.copy())
            # times.append(data.time)
            # renderer.update_scene(data,camera='track1',scene_option=scene_option)
            # frame = renderer.render()
            # frames.append(frame)
            
            # qpos_all.append(data.qpos.copy())
            # ncon_all.append(data.ncon)
            # rollout.append(data)
    # frames_all.append(frames)
    sensor_data_all.append(sensordata)
# policy_data['sensordata'] = np.array(sensor_data_all)
# media.show_video(frames, fps=50)


In [45]:
# policy_data['sensordata'] = np.array(sensor_data_all)
# ioh5.save(f'/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/Torque_control_multiclip_{run_id}.h5', policy_data)


In [None]:
# Aliases used by the figure section: policy qpos per clip, reference qpos, and
# the policy's recorded sensordata.
qposes_rollout = policy_data['qposes']
qposes_ref = ref_data
ground_force = policy_data['sensordata']
ground_force.shape

# Figure making

In [48]:
# Reload the saved policy rollouts, the AMP data and the reference clip from disk.
policy_data = ioh5.load(f'/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/Torque_control_multiclip_{run_id}.h5')
amp_data = ioh5.load('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/amp_control_clip.h5')

with open(reference_path, "rb") as file:
    # Use pickle.load() to load the data from the file
    reference_clip = pickle.load(file)
# Reference qpos per frame: position, quaternion and joint values concatenated.
ref_data = np.concatenate([reference_clip.position,reference_clip.quaternion,reference_clip.joints], axis=-1)
# ref_data = ioh5.load('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/ref_multiclip_qpos.h5')

In [49]:

def find_onsets_offsets(signal, threshold=0.5):
    """Locate rising and falling threshold crossings in a square-wave signal.

    The samples are scanned once with a two-state (low/high) machine that
    starts low. While low, a sample strictly above ``threshold`` records an
    onset. While high, a sample strictly below ``threshold`` records an
    offset. A sample equal to the threshold never changes state.

    Args:
        signal: Iterable of numeric samples (e.g. a 1-D numpy array).
        threshold: Crossing level (default: 0.5).

    Returns:
        ``(onsets, offsets)``: two lists of sample indices. If the signal
        ends while high, the last onset has no matching offset, so
        ``onsets`` can be one element longer than ``offsets``.
    """
    rising_edges, falling_edges = [], []
    is_high = False
    for idx, sample in enumerate(signal):
        if not is_high:
            if sample > threshold:
                rising_edges.append(idx)
                is_high = True
        elif sample < threshold:
            falling_edges.append(idx)
            is_high = False
    return rising_edges, falling_edges


In [None]:
# Fields available in the reloaded datasets.
policy_data.keys(), amp_data.keys()

In [None]:

# Reshape flat sensor readings to (clips, timesteps, end effector, xyz).
# The clip count comes from the data itself: reshaping (clips, T, F) with a
# hard-coded leading size keeps the time axis intact only when that size
# equals the actual number of clips.
amp_force = amp_data['sensordata'].reshape(*amp_data['sensordata'].shape[:2],-1,3)
ctrl_force = policy_data['sensordata'].reshape(*policy_data['sensordata'].shape[:2],-1,3)

# Copy so the in-place thresholding below does not modify the loaded data.
amp_force = np.array(amp_force)
ctrl_force = np.array(ctrl_force)

# Treat magnitudes below 1e-4 as zero.
amp_force[np.abs(amp_force)<1e-4]=0
# Compare the first 501 AMP steps with the policy rollouts from step 500 onward.
amp_force_clip = amp_force[:,:501].copy()
ctrl_force[np.abs(ctrl_force)<1e-4]=0
ctrl_force_clip = ctrl_force[:,500:].copy()
amp_force_clip.shape, ctrl_force_clip.shape # nclips x timesteps x end_eff x xyz

In [52]:
# Moving-average smoothing along time (axis=1) with an N-sample box filter.
# mode='same' keeps the output as long as the input (for clips of at least N
# samples). Samples whose smoothed magnitude exceeds 0.01 are set to 1 in the
# binary masks.
N = 10
amp_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='same'), axis=1, arr=amp_force_clip)
amp_binary = np.zeros_like(amp_smooth)
amp_binary[np.abs(amp_smooth)>0.01] = 1
ctrl_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='same'), axis=1, arr=ctrl_force_clip)
ctrl_binary = np.zeros_like(ctrl_smooth)
ctrl_binary[np.abs(ctrl_smooth)>0.01] = 1

In [53]:

def calculate_peakforces(sdata_binary, sdata, min_duration=10, threshold=0.5):
    """Summarise force magnitude over each above-threshold bout.

    For every (clip, end effector, axis) trace, bouts are located in
    ``sdata_binary`` with ``find_onsets_offsets``. Bouts shorter than
    ``min_duration`` samples are ignored. For each remaining bout, the mean
    absolute force of ``sdata`` over ``[onset, offset)`` is computed.

    Args:
        sdata_binary: 4-D array (clips, timesteps, end effectors, axes) of
            binary masks.
        sdata: 4-D force array with the same layout as ``sdata_binary``.
        min_duration: Minimum bout length in samples. Default 10, the value
            previously hard-coded.
        threshold: Threshold forwarded to ``find_onsets_offsets``. Default 0.5,
            its own default.

    Returns:
        peak_force_all: (clips, end effectors, axes) array. Mean over bouts of
            the per-bout mean |force|; NaN where no bout qualifies.
        peak_force_all_std: Same shape. Std over bouts of the per-bout mean
            |force|; NaN where no bout qualifies.
        sdata_square: Array shaped like ``sdata``. Each qualifying bout is
            replaced by its signed mean |force|; everything else is 0.
    """
    n_clips, _, n_end_eff, n_dims = sdata_binary.shape
    peak_force_all = np.full((n_clips, n_end_eff, n_dims), np.nan)
    peak_force_all_std = np.full((n_clips, n_end_eff, n_dims), np.nan)
    sdata_square = np.zeros_like(sdata)

    for clip_idx in range(n_clips):
        for end_eff_idx in range(n_end_eff):
            for dim in range(n_dims):
                onsets, offsets = find_onsets_offsets(
                    sdata_binary[clip_idx, :, end_eff_idx, dim], threshold=threshold)
                bout_means = []
                # zip drops a trailing onset that has no offset (bout still open at the end).
                for on, off in zip(onsets, offsets):
                    if off - on < min_duration:
                        continue  # too short to count as a bout
                    bout = sdata[clip_idx, on:off, end_eff_idx, dim]
                    magnitude = np.nanmean(np.abs(bout))
                    # nanmean for the sign as well, so NaN samples do not turn the
                    # whole square-wave segment into NaN.
                    sdata_square[clip_idx, on:off, end_eff_idx, dim] = np.sign(np.nanmean(bout)) * magnitude
                    bout_means.append(magnitude)
                # Leave NaN when no bout qualified. This avoids the "Mean of empty
                # slice" warnings that nanmean/nanstd raise on an empty list.
                if bout_means:
                    peak_force_all[clip_idx, end_eff_idx, dim] = np.nanmean(bout_means)
                    peak_force_all_std[clip_idx, end_eff_idx, dim] = np.nanstd(bout_means)

    return peak_force_all, peak_force_all_std, sdata_square

In [None]:
# Per-clip bout statistics for the amputation (1 clip) and control (17 clips) rollouts.
peak_force_amp,peak_force_amp_std, amp_square = calculate_peakforces(amp_binary,amp_force_clip)
peak_force_ctrl,peak_force_ctrl_std, ctrl_square = calculate_peakforces(ctrl_binary,ctrl_force_clip)
# Average over clips -> (end effectors, xyz).
amp_mean_peak_force = np.nanmean(peak_force_amp,axis=0)
ctrl_mean_peak_force = np.nanmean(peak_force_ctrl,axis=0)
# Pool the spread by averaging the per-clip standard deviations.
# Taking np.nanstd of the stds would measure how much the stds vary across clips.
# For the single amputation clip that is identically 0, and its error bars would vanish.
amp_std_peak_force = np.nanmean(peak_force_amp_std,axis=0)
ctrl_std_peak_force = np.nanmean(peak_force_ctrl_std,axis=0)
amp_std_peak_force.shape,ctrl_std_peak_force.shape

In [55]:
# Display names of the control end effectors (6 claws). Presumably in the same
# order as the end-effector axis of ctrl_force — confirm.
end_eff_ctrl = [
'claw T1left',
'claw T1right',
'claw T2left',
'claw T2right',
'claw T3left',
'claw T3right',
]
# Amputation rollout: the same claws without claw T1 left (5 entries).
# Amp index i pairs with ctrl index i+1 in the plots below.
end_eff_amp = [
'claw T1 right',
'claw T2 left',
'claw T2 right',
'claw T3 left',
'claw T3 right',
]

In [56]:
# Create a masked array using the binary mask
# (samples where the mask is 0 are masked out, so the statistics below only use
# time steps flagged by ctrl_binary / amp_binary).
masked_data_ctrl = np.ma.masked_array(np.abs(ctrl_force_clip), mask=~(ctrl_binary.astype(bool))) 
masked_data_amp = np.ma.masked_array(np.abs(amp_force_clip), mask=~(amp_binary.astype(bool))) 

# Calculate the average along the specified axis
# (axis=(0,1) pools clips and time, giving an (end effectors, xyz) result).
ctr_mean_avg_force = np.ma.mean(masked_data_ctrl, axis=(0,1))
ctr_mean_std_force = np.ma.std(masked_data_ctrl, axis=(0,1))
amp_mean_avg_force = np.ma.mean(masked_data_amp, axis=(0,1))
amp_mean_std_force = np.ma.std(masked_data_amp, axis=(0,1))


In [None]:
# Bar plot of average |force| per axis for one end effector:
# control (blue, ctrl index end_eff_idx+1) vs amputation (red, amp index end_eff_idx).
# Values are multiplied by 10 and the y-label says uN.
# NOTE(review): confirm that unit conversion.
fontsize=12
end_eff_idx = 1
fig, axs = plt.subplots(1,1,figsize=(2,2))
# axs = axs.flatten()
ax = axs
ax.bar(x=np.array([1,2,3]), height = 10*ctr_mean_avg_force[end_eff_idx+1,:],width=0.25, yerr=10*ctr_mean_std_force[end_eff_idx+1,:], color='#5fc1ffff')
ax.bar(x=np.array([1.25,2.25,3.25]), height = 10*amp_mean_avg_force[end_eff_idx,:],width=0.25, color='r')
ax.set_xticks([1,2,3])
# ax.set_yticks([0,10,20,30])
ax.set_xticklabels(['x','y','z'])
ax.set_ylabel(f'{end_eff_amp[end_eff_idx]} average \n force ($\mu$N)')
ax.legend(['Ctrl','Amp'],frameon=False,fontsize=fontsize,loc='upper right',bbox_to_anchor=(.4,1.0),labelcolor='linecolor',handlelength=0,handleheight=0,ncols=1,columnspacing=.1)
# Dashed horizontal reference line at 0.85 (display units).
ax.axhline(.85, color='k', linestyle='--')
# fig.savefig(fig_dir/f'average_force_comparison.png', dpi=300)

In [None]:
# Mean relative change (amp - ctrl) / ctrl of the z-axis (index 2) average force
# over the 5 end effectors shared by both rollouts, ignoring NaNs.
np.nanmean(((amp_mean_avg_force - ctr_mean_avg_force[1:])/ctr_mean_avg_force[1:])[:,2],axis=0)

In [None]:
# Relative drop in z-axis (last index) peak force, (ctrl - amp) / ctrl, for each
# of the 5 end effectors shared by both rollouts (ctrl index i+1 pairs with amp
# index i).
# The denominator is parenthesised: the former `(c - a)/10*c` divided by 10 and
# then *multiplied* by c, because `/` and `*` bind left to right.
all_difs = []
for end_eff_idx in range(len(end_eff_amp)):
    ctrl_z = 10*ctrl_mean_peak_force[end_eff_idx+1,-1]
    amp_z = 10*amp_mean_peak_force[end_eff_idx,-1]
    rel_dif = (ctrl_z - amp_z) / ctrl_z
    print(rel_dif)
    all_difs.append(rel_dif)
# nanmean: end effectors without qualifying bouts have NaN peak force.
np.nanmean(all_difs)

In [None]:
fontsize=12
# Panel 0: control-only bars for ctrl end-effector index 0 (claw T1 left),
# which has no channel in the amputation rollout.
end_eff_idx = 0
fig, axs = plt.subplots(3,2,figsize=(6,6))
axs = axs.flatten()
ax = axs[0]
ax.bar(x=np.array([1,2,3]), height = 10*ctrl_mean_peak_force[end_eff_idx,:],width=0.25, yerr=10*ctrl_std_peak_force[end_eff_idx,:], color='#5fc1ffff')
ax.set_xticks([1,2,3])
ax.set_xticklabels(['x','y','z'])
# Label from the control list. end_eff_amp is offset by one (its index 0 is
# ctrl index 1), so end_eff_amp[0] named the wrong claw here.
ax.set_ylabel(f'{end_eff_ctrl[end_eff_idx]} \n peak force ($\\mu$N)')
ax.axhline(.7, color='k', linestyle='--')

# Panels 1-5: control (blue, ctrl index i+1) vs amputation (red, amp index i).
# Peak force per axis, x10 for display, with error bars from the *_std_peak_force arrays.
for end_eff_idx in range(len(end_eff_amp)):
    ax = axs[end_eff_idx+1]
    ax.bar(x=np.array([1,2,3]), height = 10*ctrl_mean_peak_force[end_eff_idx+1,:],width=0.25, yerr=10*ctrl_std_peak_force[end_eff_idx+1,:], color='#5fc1ffff')
    ax.bar(x=np.array([1.25,2.25,3.25]), height = 10*amp_mean_peak_force[end_eff_idx,:],width=0.25, yerr=10*amp_std_peak_force[end_eff_idx,:], color='r')
    ax.set_xticks([1,2,3])
    ax.set_xticklabels(['x','y','z'])
    # Dashed reference levels: 0.7 for amp index 0 (T1 right),
    # 0.85 for amp indices 1-2 (T2 left/right).
    if end_eff_idx == 0:
        ax.axhline(.7, color='k', linestyle='--')
    elif end_eff_idx in (1, 2):
        ax.axhline(.85, color='k', linestyle='--')
    ax.set_ylabel(f'{end_eff_amp[end_eff_idx]} \n peak force ($\\mu$N)')
# Legend only on the last panel; entries follow bar order (ctrl, amp).
ax.legend(['Ctrl','Amp'],frameon=False,fontsize=fontsize,loc='upper right',bbox_to_anchor=(.6,1.0),labelcolor='linecolor',handlelength=0,handleheight=0,ncols=1,columnspacing=.1)
fig.tight_layout()
# fig.savefig(fig_dir/f'peak_force_comparison.png', dpi=300)

In [None]:
# Joint indices used below to select columns of qposes/ref_data, plus the
# matching joint names from the config. Print them to pick a joint to plot.
joint_idxs = env._joint_idxs
joint_names = cfg.dataset.env_args.joint_names

for n, joint in enumerate(joint_names):
   print(f'{n}: {joint}')

In [None]:
# Overlay the reference (black) and policy (blue) angle of one joint for one clip.
# With joint_idxs an index array, [clip, :, joint_idxs].T gives (timesteps, joints).
# The time axis assumes the clip spans 2 s (np.linspace(0, 2, T)).
# NOTE(review): the y-label says R1 femur-tibia; confirm that joint i=9 of
# joint_idxs is that joint.
clip_idx = 0
joints_policy = policy_data['qposes'][clip_idx,:, joint_idxs].T
joints_ref = ref_data[clip_idx,:, joint_idxs].T
# ground_force_clip = ground_force[clip_idx,:, 0].T
fontsize = 12
fig, axs = plt.subplots(1, 1, figsize=(3, 2))
# for i in range(joints_policy.shape[1]):
i = 9
ax = axs
ax.plot(np.linspace(0,2,joints_ref.shape[0]),np.rad2deg(joints_ref[:, i]),c='k', label=f'Data', lw=2)
ax.plot(np.linspace(0,2,joints_policy.shape[0]),np.rad2deg(joints_policy[:, i]),c='#5fc1ffff', label=f'Policy',lw=1)
ax.legend(frameon=False,fontsize=fontsize,loc='upper left',bbox_to_anchor=(.01,1.3),labelcolor='linecolor',handlelength=0,handleheight=0,ncols=2,columnspacing=.1)
ax.set_yticks([0,20,40])
ax.set_xticks([0,1,2])
ax.set_xlabel('time (s)')
ax.set_ylabel('R1 femur-tibia \n angle (deg)')
# ax.set_title(f'{joint_names[i]}')


# N=10
# ctrl_force = policy_data['sensordata'].reshape(17,policy_data['sensordata'].shape[1],-1,3)
# ctrl_force = np.array(ctrl_force)
# ctrl_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=1, arr=ctrl_force)
# amp_force = amp_data['sensordata'].reshape(1,amp_data['sensordata'].shape[1],-1,3)
# amp_force = np.array(amp_force)
# amp_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=1, arr=amp_force)

# ax = axs[1]
# ax.set_xticks([0,1,2])
# ax.plot(np.linspace(0,2,ctrl_smooth.shape[1]),np.abs(ctrl_smooth[clip_idx,:,2,2]), label=f'ref')
# ax.plot(np.linspace(0,2,ctrl_binary.shape[1]),np.abs(ctrl_binary[0,:ctrl_smooth.shape[1],1,2]), label=f'ref',zorder=0)
# ax.plot(joints_policy[:, i], label=f'policy')
# # ax.legend()
# ax.set_xlabel('timesteps')
# ax.set_ylabel('joint angle (rad)')
# ax.set_title(f'Joint: {joint_names[i]}')
fig.tight_layout()
# fig.savefig(fig_dir / 'joints_angles.svg', dpi=300)

In [62]:
# Collapse xyz into two components per end effector: the mean of x and y ('xy')
# and z. Done for the control/amputation peak-force means and their stds.
xy_z_forces = np.stack([np.mean(ctrl_mean_peak_force[:,:2],axis=1),ctrl_mean_peak_force[:,-1]],axis=-1)
xy_z_forces_std = np.stack([np.mean(ctrl_std_peak_force[:,:2],axis=1),ctrl_std_peak_force[:,-1]],axis=-1)
xy_z_amp_forces = np.stack([np.mean(amp_mean_peak_force[:,:2],axis=1),amp_mean_peak_force[:,-1]],axis=-1)
xy_z_amp_forces_std = np.stack([np.mean(amp_std_peak_force[:,:2],axis=1),amp_std_peak_force[:,-1]],axis=-1)

In [None]:
# Figure: (top) control square-wave force of one end effector against a shaded
# reference band; (bottom) xy vs z average peak force, control vs amputation.
# NOTE(review): this cell rebinds the globals ctrl_force, ctrl_smooth, amp_force
# and amp_smooth to full-length, N=10, mode='full' smoothed arrays. Later cells
# that read them will see these versions.
clip_idx = 1
N=10
fontsize=12
ctrl_force = policy_data['sensordata'].reshape(17,policy_data['sensordata'].shape[1],-1,3)
ctrl_force = np.array(ctrl_force)
ctrl_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=1, arr=ctrl_force)
amp_force = amp_data['sensordata'].reshape(1,amp_data['sensordata'].shape[1],-1,3)
amp_force = np.array(amp_force)
amp_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=1, arr=amp_force)

fig, axs = plt.subplots(2, 1, figsize=(2.5, 4))
ax = axs[0]
end_eff_idx = 1
# ax.plot(np.linspace(0,1,ctrl_smooth.shape[1]),10*np.abs(ctrl_smooth[clip_idx,:,2,0]), label=f'ctrl', c='r')
# Bout-averaged |x-force| (x10) of ctrl end effector 1 (claw T1 right, "R1"),
# with time normalised to [0, 1].
ax.plot(np.linspace(0,1,ctrl_square.shape[1]),10*np.abs(ctrl_square[clip_idx,:,end_eff_idx,0]), label=f'ctrl', c='#5fc1ffff')
# ax.plot(np.linspace(0,2,amp_smooth.shape[1]),10*np.abs(amp_smooth[clip_idx,:,2,0]), label=f'amp', c='b')
ax.set_xticks([0,0.5,1])
ax.set_yticks([0,0.5,1])
# Reference level 0.7 with a +/-0.1 shaded band.
ax.axhline(0.7, color='k', linestyle='--')
ax.fill_between(np.linspace(0,1,ctrl_smooth.shape[1]),.7+.1,.7-.1, color='k', alpha=0.2)
# ax.legend()
ax.set_xlabel('time (s)', fontsize=fontsize)
ax.set_ylabel('R1 average \n force ($\mu$N)', fontsize=fontsize)

# Short labels for the amputation end effectors. This rebinds the longer
# end_eff_amp names defined earlier.
end_eff_amp = [
'R1',
'L2',
'R2',
'L3',
'R3',
]
fontsize=12
end_eff_idx = 0
# axs = axs.flatten()
ax = axs[1]
# Bars: mean over end effectors of the xy / z peak force (x10 for display).
# NOTE(review): the control mean includes claw T1 left (6 effectors) while the
# amputation mean covers only 5 — confirm this is intended.
ax.bar(x=np.array([1,1.8]), height = 10*np.nanmean(xy_z_forces,axis=0),width=0.25, yerr=10*np.nanmean(xy_z_forces_std,axis=0), color='#5fc1ffff')
ax.bar(x=np.array([1.25,2.05]), height = 10*np.nanmean(xy_z_amp_forces,axis=0),width=0.25, yerr=10*np.nanmean(xy_z_amp_forces_std,axis=0), color='r')
ax.set_xticks([1.125,1.9])
ax.set_yticks([0,1,2,3])
ax.set_ylim([0,3])
ax.set_xticklabels(['xy','z'])
ax.tick_params(axis='x', labelsize=fontsize+2)
ax.set_ylabel(f'Average \nforce ($\mu$N)', fontsize=fontsize)
ax.legend(['Ctrl','Amp'],frameon=False,fontsize=fontsize,loc='upper left',bbox_to_anchor=(-.1,1.),labelcolor='linecolor',handlelength=0,handleheight=0,ncols=1,columnspacing=.1)
ax.axhline(.7, color='k', linestyle='--')
ax.fill_between([.85,2.2],.7+.1,.7-.1, color='k', alpha=0.2)

plt.tight_layout()
# fig.savefig(fig_dir / 'forces.svg', dpi=300, format='svg')

In [None]:
# Inspect the reference array's shape.
ref_data.shape

In [None]:
##### amputation data plotting #####
# Top: one joint angle of the amputation rollout vs the reference, first 500 samples.
clip_idx = 0
joints_policy = amp_data['qposes'][clip_idx,:500, joint_idxs].T
joints_ref = ref_data[clip_idx,:500, joint_idxs].T
# ground_force_clip = ground_force[clip_idx,:, 0].T
fontsize = 12
fig, axs = plt.subplots(2, 1, figsize=(5, 4))
# for i in range(joints_policy.shape[1]):
i = 9
ax = axs[0]
ax.plot(np.linspace(0,1,joints_ref.shape[0]),np.rad2deg(joints_ref[:, i]),c='k', label=f'Data', lw=2)
ax.plot(np.linspace(0,1,joints_policy.shape[0]),np.rad2deg(joints_policy[:, i]),c='r', label=f'Policy',lw=1)
ax.legend(frameon=False,fontsize=fontsize,loc='upper right',bbox_to_anchor=(.9,1.3),labelcolor='linecolor',handlelength=0,handleheight=0,ncols=2,columnspacing=.1)
ax.set_yticks([0,20,40])
ax.set_xticks([0,.5,1])
ax.set_xlabel('time (s)')
ax.set_ylabel('R1 femur-tibia \n angle (deg)')
# ax.set_title(f'{joint_names[i]}')
# i=12
# Bottom: unsmoothed |x-force| (x10) of ctrl end effector 2 and amp end effector 1
# (both the T2-left claw per the name lists). N=1 makes the moving average an identity.
# NOTE(review): the control window starts at sample 501 here but at 500 in
# ctrl_force_clip, and both lines are labelled 'ref'.
N=1
ctrl_force = policy_data['sensordata'].reshape(17,policy_data['sensordata'].shape[1],-1,3)
ctrl_force = np.array(ctrl_force)[:,501:]
ctrl_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=1, arr=ctrl_force)
amp_force = amp_data['sensordata'].reshape(1,amp_data['sensordata'].shape[1],-1,3)
amp_force = np.array(amp_force)[:,:500]
amp_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=1, arr=amp_force)

ax = axs[1]
ax.set_xticks([0,.5,1])
ax.plot(np.linspace(0,1,ctrl_smooth.shape[1]),np.abs(10*ctrl_smooth[clip_idx,:,2,0]), label=f'ref')
ax.plot(np.linspace(0,1,ctrl_smooth.shape[1]),np.abs(10*amp_smooth[0,:ctrl_smooth.shape[1],1,0]), label=f'ref',zorder=0)
# ax.plot(joints_policy[:, i], label=f'policy')
# # ax.legend()
# ax.set_xlabel('timesteps')
# ax.set_ylabel('joint angle (rad)')
# ax.set_title(f'Joint: {joint_names[i]}')
fig.tight_layout()
# fig.savefig(fig_dir / 'joints_angles.svg', dpi=300)

In [55]:
# Per-joint range of motion in degrees (max - min over all clips and time).
# qpos columns from index 7 on are presumably the joints, after 3 root-position
# and 4 root-quaternion entries.
# NOTE(review): only the control ranges drop their first 6 joints ([6:]) —
# presumably to align with the amputated model's joint set; confirm.
ctrl_joint_ranges = (np.rad2deg(np.max(policy_data['qposes'][:,:,7:],axis=(0,1))) - np.rad2deg(np.min(policy_data['qposes'][:,:,7:],axis=(0,1))))[6:]
amp_joint_ranges = np.rad2deg(np.max(amp_data['qposes'][:,:,7:],axis=(0,1))) - np.rad2deg(np.min(amp_data['qposes'][:,:,7:],axis=(0,1)))

In [None]:
# Compare the two range-of-motion profiles joint by joint.
fig,axs = plt.subplots(1,1,figsize=(5,3))
ax=axs
ax.plot(ctrl_joint_ranges)
ax.plot(amp_joint_ranges)
ax.set_xticks(np.arange(0,len(joint_names)))
ax.set_xticklabels(joint_names,rotation=90)

In [None]:
# Quick look at two qpos columns 6 apart (14 and 20) in the first control clip.
plt.plot(policy_data['qposes'][0,:,14])
plt.plot(policy_data['qposes'][0,:,14+6])

In [None]:
# Columns 7 and 13 of the amputation rollout, first 1000 samples.
plt.plot(amp_data['qposes'][0,:1000,7])
plt.plot(amp_data['qposes'][0,:1000,7+6])

In [None]:
# Binary-mask raster of the amputation rollout: z axis (last index) over time
# and end effectors.
fig,ax = plt.subplots(1,1)
ax.pcolormesh(amp_binary[0,:,:,-1])

In [None]:
# One binary-mask raster per control clip: z axis (last index) over time and end effectors.
for n in range(ctrl_binary.shape[0]):
    fig,ax = plt.subplots(1,1)
    # Drop ctrl end effector 0 (claw T1 left) so the panel lines up with
    # amp_binary, which has no T1-left channel. (Was the SyntaxError `1f:`.)
    ax.pcolormesh(ctrl_binary[n,:,1:,-1])

In [112]:
# Multi-line display names for six T2-left joints.
joint_names2 = [
    'coxa flexion \n T2 left',
    'coxa twist \n T2 left',
    'femur twist \n T2 left',
    'femur \n T2 left',
    'tibia \n T2 left',
    'tarsus \n T2 left',
]

In [None]:
ctrl_smooth.shape, amp_smooth.shape

In [68]:
##### amputation data plotting #####
# Rebuild unsmoothed (N=1, mode='same') force arrays: control from sample 500 on,
# amputation first 500 samples. This rebinds the ctrl_* / amp_* globals.
clip_idx = 0

N=1
ctrl_force = policy_data['sensordata'].reshape(17,policy_data['sensordata'].shape[1],-1,3)
ctrl_force = np.array(ctrl_force)[:,500:]
ctrl_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='same'), axis=1, arr=ctrl_force)
amp_force = amp_data['sensordata'].reshape(1,amp_data['sensordata'].shape[1],-1,3)
amp_force = np.array(amp_force)[:,:500]
amp_smooth = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='same'), axis=1, arr=amp_force)


In [None]:
ctrl_binary.shape,ctrl_force_clip.shape,

In [None]:
# Recompute peak-force statistics and square-wave reconstructions from the clipped arrays.
peak_force_ctrl,peak_force_ctrl_std, ctrl_square = calculate_peakforces(ctrl_binary,ctrl_force_clip)
peak_force_amp,peak_force_amp_std, amp_square = calculate_peakforces(amp_binary,amp_force_clip)
ctrl_square.shape, amp_square.shape

In [None]:
# NOTE(review): `times` is assigned in the next cell; this only works after that cell has run.
times.shape, ctrl_smooth.shape, amp_smooth.shape

In [None]:
# Raw |force| and its bout-averaged square wave (both x10).
# Top: control y axis of ctrl index end_eff_idx+1. Bottom: amputation x axis of amp index end_eff_idx.
# NOTE(review): the axes differ (y vs x); both top lines are labelled 'Ctrl' and both
# bottom lines 'Amp'. times[:-1] assumes amp_smooth is one sample shorter than
# ctrl_smooth — confirm lengths.
clip_idx = 0
# ground_force_clip = ground_force[clip_idx,:, 0].T
fontsize = 12
fig, axs = plt.subplots(2, 1, figsize=(5, 4))
times = np.linspace(0,1,ctrl_smooth.shape[1])
end_eff_idx = 0
i = 9
ax = axs[0]
ax.plot(times,np.abs(10*ctrl_smooth[clip_idx,:,end_eff_idx+1,1]), label=f'Ctrl')
ax.plot(times,np.abs(10*ctrl_square[clip_idx,:,end_eff_idx+1,1]), label=f'Ctrl')

ax.legend(frameon=False,fontsize=fontsize,loc='upper right',bbox_to_anchor=(.9,1.3),labelcolor='linecolor',handlelength=0,handleheight=0,ncols=2,columnspacing=.1)
# ax.set_yticks([0,20,40])
# ax.set_xticks([0,.5,1])
ax.set_xlabel('time (s)')
# ax.set_ylabel('R1 femur-tibia \n angle (deg)')

clip_idx = 0
ax = axs[1]
ax.set_xticks([0,.5,1])
ax.plot(times[:-1],np.abs(10*amp_smooth[clip_idx,:,end_eff_idx,0]), label=f'Amp')
ax.plot(times,np.abs(10*amp_square[clip_idx,:,end_eff_idx,0]), label=f'Amp')
# ax.plot(joints_policy[:, i], label=f'policy')
# # ax.legend()
# ax.set_xlabel('timesteps')
# ax.set_ylabel('joint angle (rad)')
# ax.set_title(f'Joint: {joint_names[i]}')
fig.tight_layout()
# fig.savefig(fig_dir / 'joints_angles.svg', dpi=300)

In [None]:
# NOTE(review): mean_error is assigned in the next cell; run that cell first.
np.mean(mean_error), np.std(mean_error)

In [None]:
# Per-joint mean absolute error (deg) between reference and policy angles over time,
# plus its std. Uses whichever joints_ref/joints_policy the most recent plotting cell set.
mean_error = np.rad2deg(np.mean(np.abs(joints_ref[:joints_ref.shape[0],:] - joints_policy[:joints_ref.shape[0],:]),axis=0))#[12:18]
std_error = np.std(np.rad2deg(np.abs(joints_ref[:joints_ref.shape[0],:] - joints_policy[:joints_ref.shape[0],:])),axis=0)#[12:18]
fig, axs = plt.subplots(1,1,figsize=(6, 5))
ax = axs
xs = 2*np.arange(mean_error.shape[0])
ax.errorbar(x=xs,y=mean_error,yerr=std_error,fmt='o',c='k',markersize=10,elinewidth=2,capsize=5)
ax.set_xticks(xs)
ax.set_xticklabels(joint_names,rotation=90, fontsize=fontsize)
ax.set_ylabel('MAE (deg)', fontsize=fontsize)
ax.set_yticks([0,2,4])
ax.tick_params(axis='x', labelsize=fontsize)
ax.tick_params(axis='y', labelsize=fontsize)
ax.set_xlim(-.5,xs[-1]+.5)
plt.tight_layout()
plt.show()
# fig.savefig(fig_dir/'mean_abs_error.svg',bbox_inches='tight',dpi=300,transparent=True)

In [50]:
# Hard-coded measured force values for the middle legs, as nested lists with
# sublists of varying length.
# NOTE(review): the grouping (presumably one sublist per recording/animal) and
# the units are not stated here — confirm.
middle_leg_force = [
    [
        0.095739335,
        0.123029839,
        0.173707312,
        0.113510692,
        0.145081388,
        0.093068495,
        0.105155757,
    ],
    [
        0.098273208,
        0.064990435,
        0.162030948,
    ],
    [
        0.153778738,
        0.165660551,
        0.147786469,
        0.092143973,
    ],
    [
        0.143848693,
        0.169666811,
    ],
    [
        0.047150595,
        0.03567968,
        0.072112675,
        0.038008105,
        0.069715767,
        0.111216509,
    ]
]

# Hard-coded measured force values for the front leg (single group).
front_leg_force = [
    [
        0.155970196,
        0.117311502,
        0.082419377,
        0.130802667,
        0.093787567,
        0.108100529,
    ]
]

In [None]:
# NOTE(review): `frames` is only assigned by the rendering loop further down; run that cell first.
media.show_video(frames, fps=50)

In [64]:
# Amputation sensor data reshaped to (1 clip, timesteps, end effectors, xyz).
sensordata = np.array(amp_data['sensordata'])
sensordata= sensordata.reshape(1,sensordata.shape[1],-1,3)

In [None]:
# z-force trace of end effector 1.
plt.plot(sensordata[0,:,1,2])
# plt.plot(sensordata[:,1,0])

In [8]:
# Load the multi-clip reference qpos and torque-control rollouts from HDF5.
ref_multiclip = ioh5.load('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/ref_multiclip_qpos.h5')
torque_multiclip = ioh5.load('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/Torque_control_multiclip.h5')

In [24]:
# Stack per-clip sensor data into (clips, 1000 timesteps, end effectors, xyz).
# NOTE(review): assumes every clip has exactly 1000 samples.
all_sensor = np.stack([torque_multiclip[f'clip{n:02d}']['sensordata'] for n in range(env._n_clips)]).reshape(env._n_clips,1000,-1,3)

In [17]:
# 1 (mg cm) / (s^2) = 1.0e-8 newtons

In [20]:
# sensordata = [state.pipeline_state.sensordata for state in rollout]

In [None]:
# Plot per-claw forces (x, y, z) in newtons. The 1e-8 factor converts mg*cm/s^2 to N.
# Smoothing: N-sample moving average along time (mode='full' adds N-1 samples).
# NOTE(review): with `sensordata` as reshaped above, [:,:-6] slices the time axis.
# The commented-out list-of-steps form suggests it was meant to drop 6 trailing
# sensor values per step — confirm.
N = 5
# sensordata = [state.pipeline_state.sensordata for state in rollout]
sdata = 1e-8*(np.stack(sensordata)[:,:-6].reshape(-1,6,3)) # Time x end_eff x xyz, x=forward
sdata = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=0, arr=sdata)
end_eff = [
'claw_T1_left',
'claw_T1_right',
'claw_T2_left',
'claw_T2_right',
'claw_T3_left',
'claw_T3_right',
]
fig, axs = plt.subplots(3, 2, figsize=(10, 10), sharey=True)
axs = axs.flatten()
t = 0
dt = 10000
for n in range(len(end_eff)):
    ax = axs[n]
    ax.plot(sdata[t:t+dt,n,0], label='x')
    ax.plot(sdata[t:t+dt,n,1], label='y')
    ax.plot(sdata[t:t+dt,n,2], label='z')
    ax.set_title(end_eff[n])
    ax.set_xlabel('timesteps')
    ax.set_ylabel('force (N)')
axs[0].legend()
plt.tight_layout()
# fig.savefig(fig_dir/'claw_force.png', dpi=300)

In [None]:
# Same plot with heavier smoothing (N=50).
# NOTE(review): unlike the cell above, there is no [:,:-6] slice here.
N = 50
sdata = 1e-8*(np.stack(sensordata).reshape(-1,6,3)) # Time x end_eff x xyz, x=forward
sdata = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=0, arr=sdata)

fig, axs = plt.subplots(3, 2, figsize=(10, 10), sharey=True)
axs = axs.flatten()
for n in range(len(end_eff)):
    ax = axs[n]
    ax.plot(sdata[:,n,0])
    ax.plot(sdata[:,n,1])
    ax.plot(sdata[:,n,2])
# plt.plot(sdata[:,:,2])

In [None]:
# Load FicTrac tracking output for the amputation recording.
import pandas as pd
fictrac_data = pd.read_csv('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/Amputation/09302024_fly2_0_R1C1_fictrac.csv')

In [None]:
# Integrated FicTrac displacements. Dividing by 10 presumably converts mm to cm
# (the model's length unit per the mg*cm force-unit note above), despite the _mm names.
fictrac_y_mm = (fictrac_data['fictrac_int_y_mm']/10).values
fictrac_x_mm = (fictrac_data['fictrac_int_x_mm']/10).values
fictrac_forward_mm = (fictrac_data['fictrac_int_forward_mm']/10).values

In [142]:

# Load a fitted qpos trajectory plus keypoint data for the amputation recording.
# NOTE(review): pickle.load can execute arbitrary code — only use on trusted files.
data_path = '/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/Amputation/Fruitfly_fit_amp2.p'
with open(data_path, "rb") as file:
    d = pickle.load(file)
    qposes = np.array(d["qpos"])
    kp_data = np.array(d["kp_data"])
    kp_names = d["kp_names"]
    offsets = d["offsets"]
qposes = jp.array(qposes)

In [143]:
# Shift the root position by the integrated FicTrac y displacement, written into
# the index-0 (x) component, and overwrite qpos[3:6].
# NOTE(review): if qpos[3:7] is the root quaternion (as in ref_data's layout),
# writing 11*pi/6 into qpos[3] looks like an angle placed in a quaternion slot —
# confirm. dpos must also have the same number of rows as qposes.
qpos_all = []
# for n in range(fictrac_y_mm.shape[0]):
dpos = jp.zeros((fictrac_y_mm.shape[0],3))
dpos = dpos.at[1:,0].set(jp.cumsum(jp.diff(fictrac_y_mm)))
# dpos = dpos.at[1:,0].set(jp.cumsum(jp.diff(fictrac_x_mm)))
qposes = qposes.at[:,:3].set(qposes[:,:3]+dpos)
qposes = qposes.at[:,3].set(11*np.pi/6)
qposes = qposes.at[:,4].set(0)
qposes = qposes.at[:,5].set(0)
    # qpos_all.append(qposes[:,:3]+dpos)
# qpos_all = jp.stack(qpos_all)
# qposes = qposes.at[:,:3].set(qpos_all)

In [75]:
positions = qposes[1:, 0]

In [76]:
# Alternative: add the integrated forward displacement to the root x position.
# NOTE(review): if run after the cell above, the offsets accumulate.
qposes = qposes.at[1:,0].set(positions + np.cumsum(np.diff(fictrac_forward_mm)))
qposes = qposes.at[:,3].set(3.1415)

In [144]:
model_path = ("/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_force_fast_amp.xml")

# Load the amputation model. MjSpec.from_file is a static constructor, so there
# is no need to build an empty spec first.
spec = mujoco.MjSpec.from_file(model_path)
thorax0 = spec.find_body("thorax")
mj_model = spec.compile()
# Use the solver settings of the training environment (CG solver, cfg iteration
# counts, env timestep).
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = cfg.dataset.env_args.iterations
mj_model.opt.ls_iterations = cfg.dataset.env_args.ls_iterations
mj_model.opt.timestep = env.sys.mj_model.opt.timestep
mj_data = mujoco.MjData(mj_model)

scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]

mujoco.mj_kinematics(mj_model, mj_data)
frames = []
# Render the first 200 poses from the 'track1' camera. The context manager
# releases the renderer's GL context afterwards.
with mujoco.Renderer(mj_model, height=480, width=480) as renderer:
    for qpos1 in qposes[:200]:
        mj_data.qpos = qpos1
        mujoco.mj_forward(mj_model, mj_data)
        renderer.update_scene(mj_data, camera='track1', scene_option=scene_option)
        pixels = renderer.render()
        frames.append(pixels)

In [None]:
# Play back the rendered frames at 50 fps.
media.show_video(frames, fps=50)

In [148]:

import jax
from jax import jit
from jax import numpy as jp
from flax import struct

from dm_control import mjcf
from dm_control.locomotion.walkers import rescale

import mujoco
from mujoco import mjx
from mujoco.mjx._src import smooth

import preprocessing.transformations as tr

from collections import defaultdict
from typing import Text, Union, List
import h5py
import pickle
from preprocessing.mjx_preprocess import process_clip

In [None]:
# Load the model through dm_control's mjcf parser and build MJX copies.
root = mjcf.from_path(model_path)

# No rescaling is applied; the model is used as loaded.

mj_model = mjcf.Physics.from_mjcf_model(root).model.ptr
mj_data = mujoco.MjData(mj_model)

# Initialize MuJoCo model and data structures & place into GPU
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

# Convert the qposes into a reference clip; max_qvel and dt are passed through to process_clip.
# NOTE(review): confirm that dt=1/300 matches the data's frame rate.
ref_clip = process_clip(qposes, mjx_model, mjx_data, max_qvel=20, dt=1/300)

In [157]:
# Pickle the reference clip to <data_dir>/clips/amp_data.p, creating the folder if needed.
# reference_path = Path(cfg.paths.data_dir)/ f"clips/all_clips_interp.p"
reference_path = Path(cfg.paths.data_dir)/ f"clips/amp_data.p"
reference_path.parent.mkdir(parents=True, exist_ok=True)

with open(reference_path, "wb") as file:
    # Use pickle.dump() to save the data to the file
    pickle.dump(ref_clip, file)

In [None]:
ref_clip.joints.shape

In [None]:
# Path of the integrated FicTrac x/y displacement.
plt.plot(fictrac_x_mm,fictrac_y_mm)

In [None]:
# Static render of the visual-only fly model in its default pose (qpos0), with
# transparency enabled, from the 'side' camera.
# NOTE(review): `camera` is set to 'back' but update_scene uses camera='side'.
# The variable only feeds the commented-out save path, so a saved file would be misnamed.
# NOTE(review): this renderer is never closed; `with mujoco.Renderer(...)` would release it.
# xml_path = Path('/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_force_fast_amp.xml')
xml_path = '/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_viz_only.xml'
spec = mujoco.MjSpec()
spec = spec.from_file(xml_path)
mj_model = spec.compile()
mj_data = mujoco.MjData(mj_model)

camera = 'back'

scene_option = mujoco.MjvOption()
scene_option.geomgroup[:] = [1, 1, 0, 0, 0, 0]
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]
# scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
# scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True
# scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = False
# scene_option.flags[mujoco.mjtVisFlag.mjVIS_ACTUATOR] = False
# scene_option.flags[mujoco.mjtVisFlag.mjVIS_LIGHT] = False
# scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONVEXHULL] = False
# scene_option.flags[mujoco.mjtRndFlag.mjRND_SHADOW] = False
# scene_option.flags[mujoco.mjtRndFlag.mjRND_REFLECTION] = False
# scene_option.flags[mujoco.mjtRndFlag.mjRND_SKYBOX] = False
# scene_option.flags[mujoco.mjtRndFlag.mjRND_FOG] = False
# scene_option.flags[mujoco.mjtVisFlag.mjVIS_STATIC] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = True
# mj_data.qpos = np.concatenate([qposes_stac[n] for n in range(0,25,5)])
mujoco.mj_forward(mj_model, mj_data)
renderer = mujoco.Renderer(mj_model, height=420, width=640)


# mj_data.mocap_pos = mod_frames[0].copy()
mujoco.mj_forward(mj_model, mj_data)

renderer.update_scene(mj_data, camera='side', scene_option=scene_option)
pixels = renderer.render()
Image.fromarray(pixels)
# renderer.update_scene(mj_data, camera='side', scene_option=scene_option)
# pixels = renderer.render()
# Image.fromarray(pixels)
# im = Image.fromarray(pixels)
# im.save(fig_dir/f'example_mocap_{camera}.png',format='png',dpi=(300,300))
# Image.fromarray(pixels)
# im.save('example_mocap.png',format='png')
# media.show_image(pixels, title='amp_example.png',)

# Model Plotting

In [8]:
from PIL import Image

In [None]:
# Render the visual-only fly model in its default pose from the 'hero' camera.
# Contact points are shown; joints, actuators, lights, convex hulls, shadows,
# reflections, skybox and fog are hidden. The image is saved as a PNG under fig_dir.
# NOTE(review): this renderer is never closed; `with mujoco.Renderer(...)` would release it.
# xml_path = Path('/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_force_fast_amp.xml')
xml_path = '/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_viz_only.xml'
spec = mujoco.MjSpec()
spec = spec.from_file(xml_path)
mj_model = spec.compile()
mj_data = mujoco.MjData(mj_model)

camera = 'hero'

scene_option = mujoco.MjvOption()
scene_option.geomgroup[:] = [1, 1, 0, 0, 0, 0]
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
# scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_ACTUATOR] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_LIGHT] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONVEXHULL] = False
scene_option.flags[mujoco.mjtRndFlag.mjRND_SHADOW] = False
scene_option.flags[mujoco.mjtRndFlag.mjRND_REFLECTION] = False
scene_option.flags[mujoco.mjtRndFlag.mjRND_SKYBOX] = False
scene_option.flags[mujoco.mjtRndFlag.mjRND_FOG] = False
# scene_option.flags[mujoco.mjtVisFlag.mjVIS_STATIC] = False
# scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = True
# mj_data.qpos = np.concatenate([qposes_stac[n] for n in range(0,25,5)])
mujoco.mj_forward(mj_model, mj_data)
renderer = mujoco.Renderer(mj_model, height=420, width=640)


# mj_data.mocap_pos = mod_frames[0].copy()
mujoco.mj_forward(mj_model, mj_data)

renderer.update_scene(mj_data, camera=camera, scene_option=scene_option)
pixels = renderer.render()

im = Image.fromarray(pixels)
im.save(fig_dir/f'example_mocap_{camera}.png',format='png',dpi=(300,300))
Image.fromarray(pixels)
# im.save('example_mocap.png',format='png')
# media.show_image(pixels, title='amp_example.png',)

In [None]:
# Render the task visualisation model from camera 0. Joints, actuators and convex
# hulls are shown, and every site whose name contains '-1' is coloured red.
xml_path = task_path / 'task_viz.xml'
# xml_path = task_path.parent / 'fruitfly_force_pair.xml'
# xml_path = task_path.parent / 'fruitfly_force_seq.xml'
# xml_path = task_path.parent / 'fruitfly_force_fast.xml'

mj_model = mujoco.MjModel.from_xml_path(xml_path.as_posix())

mj_data = mujoco.MjData(mj_model)

# Names/ids of mocap sites and of sites whose name contains '-1'.
mocap_names = [mj_model.site(i).name for i in range(mj_model.nsite) if 'mocap' in mj_model.site(i).name]
mocap_id = [mj_model.site(i).id for i in range(mj_model.nsite) if 'mocap' in mj_model.site(i).name]
site_names = [mj_model.site(i).name for i in range(mj_model.nsite) if '-1' in mj_model.site(i).name]
site_id = [mj_model.site(i).id for i in range(mj_model.nsite) if '-1' in mj_model.site(i).name]
# The loop variable is not named `id`: at notebook top level that would shadow
# the builtin id() for every later cell.
for site_idx in site_id:
    mj_model.site(site_idx).rgba = [1,0,0,1]

camera_id = 0

scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 0, 1, 0]
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_ACTUATOR] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_LIGHT] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONVEXHULL] = True
scene_option.flags[mujoco.mjtRndFlag.mjRND_SHADOW] = False
scene_option.flags[mujoco.mjtRndFlag.mjRND_REFLECTION] = False
scene_option.flags[mujoco.mjtRndFlag.mjRND_SKYBOX] = False
scene_option.flags[mujoco.mjtRndFlag.mjRND_FOG] = False

mujoco.mj_kinematics(mj_model, mj_data)
renderer = mujoco.Renderer(mj_model, height=420, width=550)

mujoco.mj_forward(mj_model, mj_data)

renderer.update_scene(mj_data, camera=camera_id, scene_option=scene_option)
pixels = renderer.render()
# Use the `Image` name bound by `from PIL import Image`, as the other render cells
# do. That import does not bind the bare name `PIL`.
Image.fromarray(pixels)
# im = Image.fromarray(pixels)
# im.save(fig_path/'init_camera{}.png'.format(camera_id),format='png')