# Imports and rendering setup

In [2]:
#@title Check if MuJoCo installation was successful
# Environment setup: check that a GPU is reachable, make the NVIDIA EGL driver
# discoverable by glvnd, select EGL rendering for MuJoCo, and set XLA / CUDA
# options. The XLA/CUDA variables must be set before JAX is imported (next cell).

import distutils.util  # NOTE(review): unused in this cell; distutils was removed in Python 3.12.
import os
import subprocess

# subprocess.run raises FileNotFoundError when nvidia-smi is not installed at
# all; treat that like a non-zero exit so the user sees the helpful message.
try:
  _gpu_unavailable = subprocess.run('nvidia-smi').returncode != 0
except FileNotFoundError:
  _gpu_unavailable = True
if _gpu_unavailable:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
# NOTE(review): writing here needs write access to /usr/share/glvnd.
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
# %env MUJOCO_GL=egl
os.environ['MUJOCO_GL'] = 'egl'
os.environ['PYOPENGL_PLATFORM'] = 'egl'
try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  # Raise the actionable message and chain the underlying error as its cause,
  # so the traceback shows both the low-level failure and this guidance.
  raise RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".') from e

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_triton_softmax_fusion=true "
    "--xla_gpu_triton_gemm_any=True "
    # "--xla_gpu_enable_async_collectives=true "
    # "--xla_gpu_enable_latency_hiding_scheduler=true "
    # "--xla_gpu_enable_highest_priority_async_stream=true "
)
# Expose only physical GPU 1 to this process (must happen before JAX initialises).
os.environ['CUDA_VISIBLE_DEVICES'] = '1' # Use GPU 1

Thu Oct 17 23:20:53 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02             Driver Version: 550.107.02     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A6000               On  |   00000000:41:00.0 Off |                  Off |
| 30%   44C    P8             15W /  300W |      11MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX A6000               On  |   00

In [3]:
%load_ext autoreload
%autoreload 2
import os

# Cap JAX's GPU memory preallocation at 95% of the device (set before `import jax` below).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Use GPU 1
import functools
import jax
# jax.config.update("jax_enable_x64", True)

# Number of GPU devices JAX can see (visibility is restricted in the previous cell).
n_gpus = jax.device_count(backend="gpu")
print(f"Using {n_gpus} GPUs")
from typing import Dict
# NOTE(review): `envs` here is brax.envs; the project package `envs.Fly_Env_Brax`
# imported below does not rebind this name, but the clash is easy to trip over.
from brax import envs
import mujoco
import pickle
import warnings
import mediapy as media
import hydra
import jax.numpy as jp

from omegaconf import DictConfig, OmegaConf
from brax.training.agents.ppo import networks as ppo_networks
from custom_brax import custom_ppo as ppo
from custom_brax import custom_wrappers
from custom_brax import custom_ppo_networks
from orbax import checkpoint as ocp
from flax.training import orbax_utils
from preprocessing.mjx_preprocess import process_clip_to_train
from envs.Fly_Env_Brax import FlyTracking, FlyMultiClipTracking
# NOTE(review): star import — presumably the source of names used later but never
# imported in this notebook (e.g. Path, np, plt, ioh5); TODO confirm and import explicitly.
from utils.utils import *
from utils.fly_logging import log_eval_rollout

warnings.filterwarnings("ignore", category=DeprecationWarning)
# jax.config.update("jax_enable_x64", True)

from hydra import initialize, compose
from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra


Using 1 GPUs


2024-10-17 23:20:55.171704: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
  self.hub = sentry_sdk.Hub(client)


# Load configs

In [3]:
dataset = "multiclip"
# Hydra overrides selecting the dataset/train configs for this dataset and the machine-specific paths.
config_overrides = [f"dataset=fly_{dataset}", f"train=train_fly_{dataset}", "paths=walle"]
with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name='config.yaml',
        overrides=config_overrides,
        return_hydra_config=True,
    )
    # Register the composed config so HydraConfig.get() works outside @hydra.main.
    HydraConfig.instance().set_config(cfg)

In [4]:
# Turn every configured path except the 'user' entry into a Path and make sure the directory exists.
for path_key in cfg.paths.keys():
    if path_key == 'user':
        continue
    cfg.paths[path_key] = Path(cfg.paths[path_key])
    cfg.paths[path_key].mkdir(parents=True, exist_ok=True)

env_cfg = cfg.dataset
env_args = env_cfg.env_args
reference_path = cfg.paths.data_dir / "clips/all_clips_batch_interp.p"
# reference_path = cfg.paths.data_dir/ f"clips/0.p"
reference_path.parent.mkdir(parents=True, exist_ok=True)

# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with reference_path.open("rb") as file:
    reference_clip = pickle.load(file)


In [5]:
# ref_clip = {}

# ref_clip['angular_velocity'] = np.array(jp.repeat(reference_clip.angular_velocity[:1],axis=0,repeats=env_args['clip_length']))
# ref_clip['body_positions'] = np.array(jp.repeat(reference_clip.body_positions[:1],axis=0,repeats=env_args['clip_length']))
# ref_clip['body_quaternions'] = np.array(jp.repeat(reference_clip.body_quaternions[:1],axis=0,repeats=env_args['clip_length']))
# ref_clip['joints'] = np.array(jp.repeat(reference_clip.joints[:1],axis=0,repeats=env_args['clip_length']))
# ref_clip['joints_velocity'] = np.array(jp.repeat(reference_clip.joints_velocity[:1],axis=0,repeats=env_args['clip_length']))
# ref_clip['position'] = np.array(jp.repeat(reference_clip.position[:1],axis=0,repeats=env_args['clip_length']))
# ref_clip['quaternion'] = np.array(jp.repeat(reference_clip.quaternion[:1],axis=0,repeats=env_args['clip_length']))
# ref_clip['velocity'] = np.array(jp.repeat(reference_clip.velocity[:1],axis=0,repeats=env_args['clip_length']))
# # ioh5.save(Path(cfg.paths.data_dir)/ f"clips/{env_cfg['clip_idx']}_stand.h5", ref_clip)

# Load env

In [None]:
dataset = 'fly_multiclip'

from envs.Fly_Env_Brax import FlyTracking, FlyMultiClipTracking, _bounded_quat_dist
# NOTE(review): this re-composes `cfg` from scratch, so `cfg.paths.*` are presumably
# plain config values again rather than the Path objects made in the path-setup cell.
with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name='config.yaml',
        overrides=[f"dataset={dataset}", f"train=train_{dataset}", "paths=walle"],
        return_hydra_config=True,
    )
    HydraConfig.instance().set_config(cfg)

env_args = cfg.dataset.env_args
# Make both tracking environments available by name to brax.envs.
for env_name, env_cls in (("fly_freejnt_clip", FlyTracking),
                          ("fly_freejnt_multiclip", FlyMultiClipTracking)):
    envs.register_environment(env_name, env_cls)
# cfg_load = OmegaConf.load('/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt/run_id=21356039/logs/run_config.yaml')
# cfg_load.paths = cfg.paths
env = envs.get_environment(cfg.train.env_name, reference_clip=reference_clip, **env_args)

In [6]:
# Wrap the env for rendering rollouts, then JIT-compile its reset/step functions.
rollout_env = custom_wrappers.RenderRolloutWrapperTracking(env)
jit_reset, jit_step = jax.jit(rollout_env.reset), jax.jit(rollout_env.step)
# Initial state from a fixed seed.
state = jit_reset(jax.random.PRNGKey(0))


In [None]:
import copy

model_path = ("/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_fastviz.xml")
# Load the fly model. In recent MuJoCo releases `MjSpec.from_file` returns a new
# spec and leaves the instance it is called on empty; older releases parsed into
# the instance and returned None. Keep whichever object holds the parsed model.
spec = mujoco.MjSpec()
_loaded_spec = spec.from_file(model_path)
if _loaded_spec is not None:
    spec = _loaded_spec
thorax = spec.find_body("thorax")
first_joint = thorax.first_joint()
# first_joint.delete()
root = spec.compile()
# Mirror the solver settings from the dataset config.
root.opt.solver = {
    "cg": mujoco.mjtSolver.mjSOL_CG,
    "newton": mujoco.mjtSolver.mjSOL_NEWTON,
}[cfg.dataset.env_args.solver.lower()]
root.opt.iterations = env_args.iterations
root.opt.ls_iterations = env_args.ls_iterations
root.opt.timestep = env_args.physics_timestep
root.opt.jacobian = 0  # 0 == mjJAC_DENSE
data = mujoco.MjData(root)
mujoco.mj_forward(root, data)


n_frames = 500
height = 512
width = 512
frames = []
fps = 1/.002  # rate at which qpos frames are advanced (500 per simulated second)
clip_idx = 3

times = []
sensordata = []
scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]

scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True
qpos_all,rollout,ncon_all = [],[],[]
# NOTE(review): `mocap_qpos_tredmill` is not defined anywhere in this notebook;
# presumably the treadmill qpos sequence (cf. `mocap_qpos.reshape(581, 1800, -1)`
# in the treadmill section) — define it before running this cell.
with mujoco.Renderer(root, height, width) as renderer:
    for t in range(n_frames):
        data.qpos = mocap_qpos_tredmill[t]
        # data.qpos = np.concatenate([reference_clip.position[clip_idx][t],reference_clip.quaternion[clip_idx][t],reference_clip.joints[clip_idx][t]])
        # Step physics until simulated time reaches frame t, logging sensors each step.
        while data.time < t/fps:
            mujoco.mj_step(root, data)
            sensordata.append(data.sensordata.copy())
        times.append(data.time)
        renderer.update_scene(data,camera='track1',scene_option=scene_option)
        frame = renderer.render()
        frames.append(frame)
        qpos_all.append(data.qpos.copy())
        ncon_all.append(data.ncon)
        # Store a snapshot: appending `data` itself makes every entry alias the same
        # MjData object, which keeps changing as the simulation steps.
        rollout.append(copy.copy(data))

media.show_video(frames, fps=50)


In [None]:
end_eff = [
    'claw_T1_left',
    'claw_T1_right',
    'claw_T2_left',
    'claw_T2_right',
    'claw_T3_left',
    'claw_T3_right',
]
N = 100  # moving-average window length, in recorded physics steps
# Time x end_eff x xyz, x=forward.
# Assumes 18 sensor values per step = 6 end effectors x 3 axes — TODO confirm against the model's sensors.
# NOTE(review): 1e-8 is an unexplained scale factor carried over from the original.
sdata = 1e-8*(np.stack(sensordata).reshape(-1,6,3))
# mode='same' keeps the smoothed trace the same length as, and centred on, the raw
# trace; mode='full' lagged it by ~N/2 samples and appended N-1 ramp-down samples.
sdata = np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='same'), axis=0, arr=sdata)

fig, axs = plt.subplots(3, 2, figsize=(10, 10), sharey=True)
axs = axs.flatten()
for n, body_name in enumerate(end_eff):
    ax = axs[n]
    for axis_idx, axis_label in enumerate('xyz'):
        ax.plot(sdata[:, n, axis_idx], label=axis_label)
    ax.set_title(body_name)
    ax.set_xlabel('physics step')
axs[0].legend()
# plt.plot(sdata[:,:,2])

# Cleaning Data

In [7]:
import pandas as pd
base_path = Path('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/')
data_path = base_path / 'combined_wt_berlin_walking_v3.pq'
full_df = pd.read_parquet(data_path, engine='pyarrow')

# Per-bout summary statistics of FicTrac forward speed and yaw rate.
_fwd_col = 'fictrac_delta_rot_lab_y_mms'
_yaw_col = 'fictrac_delta_rot_lab_z_deg/s'
bout_stats = (
    full_df
    .groupby(['walking_bout_number', 'fullfile', 'Sex'])[[_fwd_col, _yaw_col]]
    .agg(['mean', 'min', 'max', 'std', 'count'])
)
# Fast bouts: mean forward speed >= 12 and never below 10.
fast = (bout_stats[(_fwd_col, 'mean')] >= 12) & (bout_stats[(_fwd_col, 'min')] >= 10)
# Straight bouts: |mean yaw rate| <= 45 and yaw rate always within [-60, 60].
straight = (
    (bout_stats[(_yaw_col, 'mean')].abs() <= 45)
    & (bout_stats[(_yaw_col, 'min')] >= -60)
    & (bout_stats[(_yaw_col, 'max')] <= 60)
)

# Model-side names.
legs = ['T1_left', 'T1_right', 'T2_left', 'T2_right', 'T3_left', 'T3_right']
joints = ['coxa', 'femur', 'tibia', 'tarsus']
xpos_geoms = joints + ['claw']
joint_names = [f'{joint}_{leg}' for leg in legs for joint in joints]
xpos_names = [f'{geom}_{leg}' for leg in legs for geom in xpos_geoms]
# physics.named.data.framepos[pos_names]
site_names = [f'tracking[{name}]' for name in xpos_names]

# Data-side column names: 6 legs x 5 keypoints x (x, y, z), grouped per keypoint.
legs_data = ['L1', 'R1', 'L2','R2', 'L3','R3']
joints_data = ['A','B','C','D','E']
coords_data = ['_x','_y','_z']
joint_pos_columns = [f'{leg}{kp}{axis}'
                     for leg in legs_data
                     for kp in joints_data
                     for axis in coords_data]

straight_bouts = bout_stats[fast & straight].index
straight_bout_num = np.array([idx[0] for idx in straight_bouts], dtype=int)

# Bout 14574 is the reference: each selected bout is shifted so its mean pose matches this one.
ref_bout = full_df.loc[full_df['walking_bout_number'] == 14574, joint_pos_columns].values.reshape(-1, 30, 3)
mean_straight_bout = ref_bout.mean(axis=0)

bout_dict = {}
for n, bout_num in enumerate(straight_bout_num):
    bout = full_df[full_df['walking_bout_number'] == bout_num]
    xpos = bout[joint_pos_columns].values.reshape(-1, 30, 3)
    bout_dict[f'walking_bout{n:02}'] = {
        'orig_xpos': xpos + (mean_straight_bout - np.mean(xpos, axis=0)),
    }

In [8]:
stac_path = Path('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/transform_mocap_fly_freejnt.p')
# Load the STAC-transformed mocap (pickle: trusted files only) and keep qpos as a JAX array.
with stac_path.open("rb") as file:
    d = pickle.load(file)
mocap_qpos = jp.array(d["qpos"])

In [9]:
from preprocessing.mjx_preprocess import process_clip, save_reference_clip_to_h5, load_reference_clip_from_h5, ReferenceClip
from jax import vmap
import mujoco.mjx as mjx

In [None]:
# Names of all FicTrac-derived columns in the bout table.
list(filter(lambda name: 'fictrac' in name, bout.columns))

In [35]:
# Per-bout FicTrac velocities divided by 10 (presumably mm/s -> cm/s, per the _cm names).
lin_vel_y_cm = []
lin_vel_x_cm = []
for n, bout_num in enumerate(straight_bout_num):
    bout = full_df[full_df['walking_bout_number'] == bout_num]
    lin_vel_y_cm.append(bout['fictrac_delta_rot_lab_y_mms'].values / 10)
    lin_vel_x_cm.append(bout['fictrac_delta_rot_lab_x_mms'].values / 10)
# lin_vel_y_cm = np.concatenate(lin_vel_y_cm,axis=0)

# Leading 0 followed by each bout's frame count; its running sum gives clip offsets.
clip_shape = np.concatenate([[0], np.array([len(v) for v in lin_vel_y_cm])])
clip_offsets = np.cumsum(clip_shape)
# All clips but the last span [offset_i, offset_{i+1}) of mocap_qpos ...
mocap_qpos_reshaped = [
    mocap_qpos[clip_offsets[i]:clip_offsets[i + 1]]
    for i in range(len(clip_shape) - 2)
]
# ... and the last clip takes whatever mocap frames remain.
mocap_qpos_reshaped.append(mocap_qpos[clip_offsets[len(clip_shape) - 2]:])

# Trim the last bout's velocities to the frames that actually exist in mocap_qpos.
last_clip_len = mocap_qpos_reshaped[-1].shape[0]
lin_vel_y_cm[-1] = lin_vel_y_cm[-1][:last_clip_len]
lin_vel_x_cm[-1] = lin_vel_x_cm[-1][:last_clip_len]

In [10]:
##### Process clips #####

# spec = mujoco.MjSpec()
# spec.from_file(cfg.dataset.env_args.mjcf_path)
# mj_model = spec.compile()

# mj_model.opt.solver = {
# "cg": mujoco.mjtSolver.mjSOL_CG,
# "newton": mujoco.mjtSolver.mjSOL_NEWTON,
# }["cg"]
# mj_model.opt.iterations = cfg.dataset.env_args.iterations
# mj_model.opt.ls_iterations = cfg.dataset.env_args.ls_iterations
# mj_model.opt.timestep = cfg.dataset.env_args.physics_timestep

# mj_data = mujoco.MjData(mj_model)

# # Initialize MuJoCo model and data structures & place into GPU
# mjx_model = mjx.put_model(mj_model)
# mjx_data = mjx.put_data(mj_model, mj_data)

# all_clips = []
# for n in range(len(mocap_qpos_reshaped)):
#     clip=process_clip(mocap_qpos_reshaped[n],mjx_model,mjx_data,max_qvel=20, dt=1/300)
#     all_clips.append(clip)

In [13]:
raw_clips_h5 = '/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/clips/all_clips_raw.h5'
# One name per mocap clip: clip00, clip01, ...
clip_names = ['clip{:02}'.format(i) for i in range(len(mocap_qpos_reshaped))]
all_clips = load_reference_clip_from_h5(raw_clips_h5, clip_names)
# save_reference_clip_to_h5('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/clips/all_clips_raw.h5',clip_names=clip_names,reference_clip=all_clips_reference)

In [44]:
# Plain-dict view of every ReferenceClip field (values are the same objects held by all_clips).
_clip_fields = ('position', 'quaternion', 'joints', 'body_positions', 'velocity',
                'joints_velocity', 'angular_velocity', 'body_quaternions')
ref_clip = {field: getattr(all_clips, field) for field in _clip_fields}

In [24]:
# reference_path = Path(cfg.paths.data_dir)/ f"clips/all_clips_interp.p"
reference_path = Path(cfg.paths.data_dir) / "clips/0.p"
reference_path.parent.mkdir(parents=True, exist_ok=True)

# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with reference_path.open("rb") as file:
    reference_clip = pickle.load(file)

In [45]:
# Integrate each clip's FicTrac velocities into a planar displacement and add it to the root position.
# NOTE(review): the y-velocity drives model x and the x-velocity drives model y —
# presumably a frame convention; confirm.
# NOTE(review): this integrates with env.dt, while the interpolation cells treat the
# clips as 300 Hz — confirm env.dt matches the clips' sampling interval.
qpos_all = []
for n in range(len(all_clips.position)):
    dpos = jp.zeros((all_clips.position[n].shape[0],3))
    dpos = dpos.at[:,0].set(jp.cumsum(lin_vel_y_cm[n])*env.dt)
    dpos = dpos.at[:,1].set(jp.cumsum(lin_vel_x_cm[n])*env.dt)

    qpos_all.append(all_clips.position[n]+dpos)

# Rebind instead of assigning element-wise. ref_clip['position'] is the same list
# object as all_clips.position, so item assignment would also overwrite
# all_clips.position, and re-running this cell would add the displacement twice.
ref_clip['position'] = list(qpos_all)

In [None]:
mocap_hz = 500
clip_len = 601
control_timestep = 0.002
phyiscs_timestep = 0.0002  # (sic) original spelling kept in case other cells use it
physics_timestep = phyiscs_timestep
# round(), not int(): int() truncates, so a ratio that lands just below an integer
# in floating point (e.g. 9.999999999999998) would silently drop a step.
physics_steps_per_control_step = round(control_timestep/physics_timestep)
max_physics_steps_per_control_step = round(
    1.0 / (mocap_hz * physics_timestep)
)
max_physics_steps_per_control_step

In [49]:
# Number of frames in the first entry of all_ref_clip['position'].
clip_len = len(all_ref_clip['position'][0])

### Interpolate data

In [None]:
n = 0
key = 'position'
ref_clip_interp = {field: [] for field in ref_clip.keys()}

# Resample clip n of one field to 1000 samples over the same time span (linear interpolation along time).
src = ref_clip[key][n]
clip_len = src.shape[0]
tmax = 1/300 * clip_len # 1/300 is original mocap hz
t = jp.linspace(0, tmax, clip_len)
t_interp = jp.linspace(0, tmax, 1000)
resampled = jp.apply_along_axis(lambda fp, x, xp: jp.interp(x, xp, fp), 0, src, x=t_interp, xp=t)
ref_clip_interp[key].append(resampled)
print(key, ref_clip_interp[key][n].shape)

In [None]:
def _resample_along_time(arr):
    """Linearly resample ``arr`` along axis 0 to 1000 samples spanning the same time range."""
    n_src = arr.shape[0]
    tmax = 1/300 * n_src # 1/300 is original mocap hz
    t_src = jp.linspace(0, tmax, n_src)
    t_dst = jp.linspace(0, tmax, 1000)
    return jp.apply_along_axis(lambda fp, x, xp: jp.interp(x, xp, fp), 0, arr, x=t_dst, xp=t_src)


ref_clip_interp = {name: [] for name in ref_clip.keys()}
for name, clips in ref_clip.items():
    for i, clip in enumerate(clips):
        ref_clip_interp[name].append(_resample_along_time(clip))
        print(name, ref_clip_interp[name][i].shape)

In [None]:
# Inspect the resampled root positions of clip 5.
ref_clip_interp['position'][5]

In [None]:
# Stack each field's list of resampled clips (all 1000 samples long) into one array with the clip index first.
for name in list(ref_clip.keys()):
    ref_clip_interp[name] = jp.stack(ref_clip_interp[name], axis=0)
    print(name, ref_clip_interp[name].shape)


In [58]:
# Pack the stacked, resampled fields into a single batched ReferenceClip.
all_clips_reference = ReferenceClip().replace(**{
    field: ref_clip_interp[field]
    for field in ('position', 'quaternion', 'joints', 'body_positions', 'velocity',
                  'joints_velocity', 'angular_velocity', 'body_quaternions')
})

In [59]:

reference_path = Path(cfg.paths.data_dir) / "clips/all_clips_batch_interp.p"
# Persist the batched reference clip with pickle.
with reference_path.open("wb") as file:
    pickle.dump(all_clips_reference, file)

In [172]:
reference_path = Path('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/clips/all_clips_list_interp.p')
# Persist whatever `all_clips` currently holds with pickle.
with reference_path.open("wb") as file:
    pickle.dump(all_clips, file)

In [20]:
# Plain-dict view of every field of reference_clip.
all_ref_clip = {
    field: getattr(reference_clip, field)
    for field in ('position', 'quaternion', 'joints', 'body_positions', 'velocity',
                  'joints_velocity', 'angular_velocity', 'body_quaternions')
}

In [19]:
reference_path = Path(cfg.paths.data_dir) / "clips/all_clips_interp.p"
# reference_path = cfg.paths.data_dir/ f"clips/0.p"
reference_path.parent.mkdir(parents=True, exist_ok=True)

# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with reference_path.open("rb") as file:
    reference_clip = pickle.load(file)

In [None]:
# Cumulative frame counts, i.e. the end offset of each clip once concatenated.
np.cumsum([len(clip) for clip in reference_clip.position])

In [71]:
# Concatenate all clips along the time axis so each field becomes one long array.
all_ref_clip = {
    field: jp.concatenate(list(getattr(reference_clip, field)), axis=0)
    for field in ('position', 'quaternion', 'joints', 'body_positions', 'velocity',
                  'joints_velocity', 'angular_velocity', 'body_quaternions')
}

In [None]:
# Build one ReferenceClip per entry of all_ref_clip.
# NOTE(review): if all_ref_clip holds the concatenated arrays, indexing [i] yields single
# frames, so this builds one ReferenceClip per frame rather than per clip — confirm intent.
_ref_fields = ('position', 'quaternion', 'joints', 'body_positions', 'velocity',
               'joints_velocity', 'angular_velocity', 'body_quaternions')
all_clips = [
    ReferenceClip().replace(**{field: all_ref_clip[field][i] for field in _ref_fields})
    for i in range(len(all_ref_clip['position']))
]

In [None]:
# Inspect entry 5 of reference_clip.position (a whole clip when position is a per-clip list).
reference_clip.position[5]

In [None]:
reference_path = Path(cfg.paths.data_dir) / "clips/all_clips_interp.p"
# reference_path = cfg.paths.data_dir/ f"clips/0.p"
reference_path.parent.mkdir(parents=True, exist_ok=True)

# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with reference_path.open("rb") as file:
    reference_clip = pickle.load(file)

In [15]:
clip_lengths = jp.array([clip.shape[0] for clip in reference_clip.position])
# Exclusive running total: the index at which each clip starts in the concatenated arrays.
_running_total = np.cumsum(clip_lengths)
clip_start_inds = np.concatenate(([0], _running_total[:-1]))

In [None]:
# Per-clip frame counts.
clip_lengths

In [146]:
# Swap every field of reference_clip for its concatenated counterpart from all_ref_clip.
reference_clip = reference_clip.replace(**{
    field: all_ref_clip[field]
    for field in ('position', 'quaternion', 'joints', 'body_positions', 'velocity',
                  'joints_velocity', 'angular_velocity', 'body_quaternions')
})

## Treadmill data

In [33]:
import pandas as pd
base_path = Path('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/')

tredmill_data = pd.read_csv(base_path / 'wt_berlin_linear_treadmill_dataset.csv')
kp_names = ['head', 'thorax', 'abdomen', 'r1', 'r2', 'r3', 'l1', 'l2', 'l3']
coords = ['_x', '_y', '_z']
# 27 columns, keypoint-major: head_x, head_y, head_z, thorax_x, ...
df_names = [f'{kp}{coord}' for kp in kp_names for coord in coords]
kp_data_all = tredmill_data[df_names].values

sorted_kp_names = kp_names
# NOTE(review): 0.3 scale and the 1800-row grouping are unexplained magic numbers.
kp_data = .3 * kp_data_all.reshape(1800, -1, 27)
belt_speed = tredmill_data['belt speed (mm/s)'].values.reshape(1800, -1)

In [21]:
stac_path = Path('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/transform_treadmill.p')
# stac_path = Path('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/transform_mocap_fly_freejnt.p')
# Load the STAC-transformed treadmill data (pickle: trusted files only) and keep qpos as a JAX array.
with stac_path.open("rb") as file:
    d = pickle.load(file)
mocap_qpos = jp.array(d["qpos"])

In [None]:
# Shape of the treadmill qpos split into 581 blocks of 1800 frames.
mocap_qpos.reshape(581,1800,-1).shape

In [None]:
# Scratch check: 1045800 / 1800 = 581 blocks (1045800 is presumably the total frame count — TODO confirm).
1045800/1800

In [10]:
# Skeleton edges as (parent, child) keypoint-index pairs, presumably into kp_names.
# NOTE(review): index 5 ('r3' under that assumption) never appears — confirm whether (1, 5) is missing.
fly_skel = ((0,1),(1,2),(1,3),(1,4),(1,6),(1,7),(1,8))


In [29]:
# Split each frame's 27 values into (keypoint, xyz), then move xyz to axis 2 to get
# (1800, 581, 3, n_keypoints). The columns are keypoint-major (head_x, head_y, head_z,
# thorax_x, ...), so reshaping the 27 straight to (3, 9) mixed coordinates of different keypoints.
# NOTE(review): not idempotent — assumes kp_data still has the (1800, 581, 27) layout
# from the treadmill-loading cell.
kp_data = kp_data.reshape(1800, 581, -1, 3).swapaxes(-1, -2)

In [None]:
# Ids of every joint except index 0 (presumably the root joint).
joint_idx = np.array([root.joint(j).id for j in range(root.njnt)])[1:]
# Claw body at the tip of each leg, ordered T1..T3, left before right.
end_eff = ['claw_' + leg for leg in
           ('T1_left', 'T1_right', 'T2_left', 'T2_right', 'T3_left', 'T3_right')]
_body_type = mujoco.mju_str2Type("body")
endeff_idxs = jp.array([mujoco.mj_name2id(root, _body_type, name) for name in end_eff])

In [100]:
model_path = ("/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_fastviz.xml")

# Load the fly model. In recent MuJoCo releases `MjSpec.from_file` returns a new
# spec and leaves the instance it is called on empty; older releases parsed into
# the instance and returned None. Keep whichever object holds the parsed model.
spec = mujoco.MjSpec()
_loaded_spec = spec.from_file(model_path)
if _loaded_spec is not None:
    spec = _loaded_spec
thorax0 = spec.find_body("thorax")
mj_model = spec.compile()
mj_model.opt.solver = {
    "cg": mujoco.mjtSolver.mjSOL_CG,
    "newton": mujoco.mjtSolver.mjSOL_NEWTON,
}["cg"]
mj_model.opt.iterations = cfg.dataset.env_args.iterations
mj_model.opt.ls_iterations = cfg.dataset.env_args.ls_iterations
mj_model.opt.timestep = env.sys.mj_model.opt.timestep
mj_data = mujoco.MjData(mj_model)

scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]

# Compute kinematics for the initial state.
mujoco.mj_kinematics(mj_model, mj_data)
# renderer = mujoco.Renderer(mj_model, height=512, width=512)
# First 500 frames of the first 1800-frame block of the treadmill qpos.
# NOTE(review): qpos is split as (581, 1800, ...) here, whereas the keypoint and
# belt-speed arrays use (1800, 581, ...) — confirm which axis is trials vs frames.
qposes_rollout = mocap_qpos.reshape(581,1800,-1)[0,:500]
# Pin qpos index 2 (presumably the root z-coordinate) to 0.05 for every frame.
qposes_rollout = qposes_rollout.at[:,2].set(.05)
# qposes_rollout = qposes_rollout.at[:,3:7].set(0)
frames = []
# Kinematic playback: set qpos, run forward, render (no time stepping).
with mujoco.Renderer(mj_model, height=512, width=512) as renderer:
    for qpos1 in qposes_rollout:
        mj_data.qpos = qpos1
        mujoco.mj_forward(mj_model, mj_data)
        renderer.update_scene(mj_data, camera='track1', scene_option=scene_option)
        pixels = renderer.render()
        frames.append(pixels)
In [101]:
# Play the rendered treadmill frames at 10 fps.
media.show_video(frames, fps=10)

0
This browser does not support the video tag.


In [None]:
# First frame's keypoints: plot coordinate 0 against coordinate 1 for each bone's endpoints.
t0 = kp_data[0, 0, :, :]
for bone in fly_skel[:7]:
    xs, ys = t0[0, bone], t0[1, bone]
    plt.plot(xs, ys)
    plt.scatter(xs, ys)

In [None]:
# Scatter coordinate 0 vs coordinate 1 of both endpoints of every bone in the first frame.
# fly_skel is a tuple of tuples, so fly_skel[n, 0] (NumPy-style) raises TypeError; take
# each bone's endpoint pair instead. Pairing x of one keypoint with y of another would
# not be a point, so both coordinates come from the same keypoints.
for n in range(len(fly_skel)):
    bone = list(fly_skel[n])
    plt.plot(kp_data[0,0,0,bone], kp_data[0,0,1,bone], '.')

# Amputation data

In [34]:
import pandas as pd
base_path = Path('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/')
# 3-D pose estimates for one amputation recording.
data_path = base_path.joinpath('Amputation', '09302024_fly2_0 R1C1_pose-3d.csv')
df = pd.read_csv(data_path)

In [35]:
# (rows, columns) of the raw amputation pose table.
df.shape

(1500, 247)

In [36]:

# Leg keypoint columns for the amputation recording. L1 is absent from legs_data
# (presumably the amputated leg), so 5 legs x 5 keypoints x (x, y, z) are used.
legs_data = ['R1', 'L2','R2', 'L3','R3']
joints_data = ['A','B','C','D','E']
coords_data = ['_x','_y','_z']
joint_pos_columns = [leg + joint + coord
                     for leg in legs_data
                     for joint in joints_data
                     for coord in coords_data]
# Body (non-leg) keypoints, each expanded to its x, y, z columns in that order.
body_kp_names = ['abdomen-tip', 'stripe-3', 'stripe-1', 'thorax-abdomen', 'head-thorax']
all_cols = joint_pos_columns + [kp + coord for kp in body_kp_names for coord in coords_data]

# (frames, keypoints, xyz); use the table's row count instead of a hard-coded 1500.
kp_data = df[all_cols].values.reshape(len(df), -1, 3)

In [46]:
model_path = ("/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_fastviz.xml")
spec = mujoco.MjSpec()
spec = spec.from_file(model_path)
mj_model = spec.compile()
# Integer site id of the T1-left coxa tracking site (no trailing comma, which would
# bind a 1-tuple instead of the id).
init_site_idx = mujoco.mj_name2id(mj_model, mujoco.mju_str2Type("site"), "tracking[coxa_T1_left]")
if init_site_idx < 0:  # mj_name2id returns -1 when the name is not found
    raise ValueError('site "tracking[coxa_T1_left]" not found in ' + model_path)
data = mujoco.MjData(mj_model)
mujoco.mj_forward(mj_model, data)
# mj_model.find_body("thorax").id
# frame0 += site_pos[0, :]  # Body-coxa T1 left joint is the data origin.


array([ 8.66499973e-02,  3.81000000e-07, -3.48224356e-02])

In [47]:
def transform_frame(frame):
    """Transform keypoints from the data reference frame to the model reference frame.

    Rotates by -90 degrees about the z-axis: (x, y, z) -> (y, -x, z). Units are
    left unchanged (the mm -> cm scaling below is disabled).

    Args:
      frame: NumPy array whose last axis holds (x, y, z), e.g. (n_keypoints, 3)
        for one frame, or any leading batch shape such as (n_frames, n_keypoints, 3).

    Returns:
      A new array of the same shape; the input is not modified.
    """
    # Advanced indexing returns a copy, so the in-place sign flip cannot touch `frame`.
    frame = frame[..., [1, 0, 2]]
    frame[..., 1] *= -1
    # Change units mm to cm.
    # frame *= 0.1
    return frame
# Rotate each frame into the model frame and translate it so the data origin sits at the
# tracking[coxa_T1_left] site. init_site_idx is a *site* id, so it indexes data.site_xpos;
# data.xpos holds body positions and is indexed by body id.
kp_transform = np.array([transform_frame(frame) + data.site_xpos[init_site_idx] for frame in kp_data])

In [48]:
# Raw keypoints, model-frame keypoints and their column names, bundled for saving.
data_dict = dict(kp_data=kp_data, kp_transform=kp_transform, kp_names=all_cols)

In [39]:
# Ordered column names used to build kp_data.
all_cols

['R1A_x',
 'R1A_y',
 'R1A_z',
 'R1B_x',
 'R1B_y',
 'R1B_z',
 'R1C_x',
 'R1C_y',
 'R1C_z',
 'R1D_x',
 'R1D_y',
 'R1D_z',
 'R1E_x',
 'R1E_y',
 'R1E_z',
 'L2A_x',
 'L2A_y',
 'L2A_z',
 'L2B_x',
 'L2B_y',
 'L2B_z',
 'L2C_x',
 'L2C_y',
 'L2C_z',
 'L2D_x',
 'L2D_y',
 'L2D_z',
 'L2E_x',
 'L2E_y',
 'L2E_z',
 'R2A_x',
 'R2A_y',
 'R2A_z',
 'R2B_x',
 'R2B_y',
 'R2B_z',
 'R2C_x',
 'R2C_y',
 'R2C_z',
 'R2D_x',
 'R2D_y',
 'R2D_z',
 'R2E_x',
 'R2E_y',
 'R2E_z',
 'L3A_x',
 'L3A_y',
 'L3A_z',
 'L3B_x',
 'L3B_y',
 'L3B_z',
 'L3C_x',
 'L3C_y',
 'L3C_z',
 'L3D_x',
 'L3D_y',
 'L3D_z',
 'L3E_x',
 'L3E_y',
 'L3E_z',
 'R3A_x',
 'R3A_y',
 'R3A_z',
 'R3B_x',
 'R3B_y',
 'R3B_z',
 'R3C_x',
 'R3C_y',
 'R3C_z',
 'R3D_x',
 'R3D_y',
 'R3D_z',
 'R3E_x',
 'R3E_y',
 'R3E_z',
 'abdomen-tip_x',
 'abdomen-tip_y',
 'abdomen-tip_z',
 'stripe-3_x',
 'stripe-3_y',
 'stripe-3_z',
 'stripe-1_x',
 'stripe-1_y',
 'stripe-1_z',
 'thorax-abdomen_x',
 'thorax-abdomen_y',
 'thorax-abdomen_z',
 'head-thorax_x',
 'head-thorax_y',
 'head-

In [49]:
# Save the bundled keypoint arrays and names to HDF5.
# NOTE(review): `ioh5` is never imported explicitly; presumably it comes from `from utils.utils import *` — confirm.
ioh5.save(base_path / 'Amputation/kp_data_amp.h5', data_dict)

In [50]:
# (frames, keypoints, 3) after the reshape above.
kp_data.shape

(1500, 30, 3)

In [17]:
# Number of leg keypoints (each contributes an x, y and z column).
len(joint_pos_columns)/3

25.0

In [None]:
'abdomen-tip_x',
'abdomen-tip_y',
'abdomen-tip_z',
'stripe-3_x',
'stripe-3_y',
'stripe-3_z',
'stripe-1_x',
'stripe-1_y',
'stripe-1_z',
'thorax-abdomen_x',
'thorax-abdomen_y',
'thorax-abdomen_z',
'head-thorax_x',
'head-thorax_y',
'head-thorax_z',
'l-eye-t_x',
'l-eye-t_y',
'l-eye-t_z',
'l-eye-b_x',
'l-eye-b_y',
'l-eye-b_z',
'r-eye-t_x',
'r-eye-t_y',
'r-eye-t_z',
'r-eye-b_x',
'r-eye-b_y',
'r-eye-b_z',