# Imports and rendering setup

In [None]:
#@title Check if MuJoCo installation was successful
# Setup cell: checks that a GPU is reachable, writes the Nvidia EGL ICD file if
# it is missing, selects the EGL rendering backend, smoke-tests the MuJoCo
# import, and sets the XLA / CUDA environment variables.

# NOTE(review): distutils was removed from the standard library in Python 3.12
# and nothing visible in this cell references it -- confirm it is still needed.
import distutils.util
import os
import subprocess
# A non-zero exit status from `nvidia-smi` is treated as "no usable GPU".
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
# %env MUJOCO_GL=egl
# Both variables are set before `import mujoco` below.
os.environ['MUJOCO_GL'] = 'egl'
os.environ['PYOPENGL_PLATFORM'] = 'egl'
try:
  print('Checking that the installation succeeded:')
  import mujoco
  # Parsing a minimal model exercises the native library end to end.
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  # The original exception is re-raised; the RuntimeError below is attached
  # as its __cause__ so the hint appears in the traceback.
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_triton_softmax_fusion=true "
    "--xla_gpu_triton_gemm_any=True "
    # "--xla_gpu_enable_async_collectives=true "
    # "--xla_gpu_enable_latency_hiding_scheduler=true "
    # "--xla_gpu_enable_highest_priority_async_stream=true "
)
# Restricts CUDA to the device with index 1; this assignment precedes the
# `import jax` in the following cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '1' # Use GPU 1

In [None]:
%load_ext autoreload
%autoreload 2
import os

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Use GPU 1
import functools
import jax
# jax.config.update("jax_enable_x64", True)

n_gpus = jax.device_count(backend="gpu")
print(f"Using {n_gpus} GPUs")
from typing import Dict
from brax import envs
import mujoco
import pickle
import warnings
import mediapy as media
import hydra
import jax.numpy as jp

from omegaconf import DictConfig, OmegaConf
from brax.training.agents.ppo import networks as ppo_networks
from custom_brax import custom_ppo as ppo
from custom_brax import custom_wrappers
from custom_brax import custom_ppo_networks
from orbax import checkpoint as ocp
from flax.training import orbax_utils
from preprocessing.mjx_preprocess import process_clip_to_train
from envs.Fly_Env_Brax import FlyTracking, FlyMultiClipTracking
from utils.utils import *
from utils.fly_logging import log_eval_rollout
from tqdm import tqdm
warnings.filterwarnings("ignore", category=DeprecationWarning)
# jax.config.update("jax_enable_x64", True)

from hydra import initialize, compose
from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra


# Load configs

In [None]:
# Collect every saved run_config.yaml below the checkpoint directory and print
# them with their index so a run can be picked by position.
base_dir ='/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt'
run_cfg_list = sorted(list(Path(base_dir).rglob('run_config.yaml')))
for n, run_cfg in enumerate(run_cfg_list):
    print(n, run_cfg)


# -1 selects the last entry of the sorted list.
cfg_num = -1
cfg = OmegaConf.load(run_cfg_list[cfg_num])
# The run id is the integer after '=' in the name of the directory two levels
# above the config file (presumably a 'run_id=<int>' folder -- verify layout).
run_id = int(run_cfg_list[cfg_num].parent.parent.stem.split('=')[1])
print(cfg.dataset.dname)
fig_dir = Path('/data/users/eabe/biomech_model/Flybody/RL_Flybody/debug/figures')

In [None]:
# Re-compose the hydra config for the selected dataset / run id and copy only
# its `paths` section onto the config loaded from the checkpoint.
dataset = cfg.dataset.dname
with initialize(version_base=None, config_path="configs"):
    cfg_temp=compose(config_name='config.yaml',overrides= [f"dataset={dataset}", f"train=train_{dataset}", "paths=walle", "version=ckpt", f'run_id={run_id}'],return_hydra_config=True,)
    # Register the composed config as the global HydraConfig instance.
    HydraConfig.instance().set_config(cfg_temp)
    
cfg.paths = cfg_temp.paths

In [None]:
# Convert every path entry except 'user' to a pathlib.Path and make sure the
# directory exists.
for k in cfg.paths.keys():
    if (k != 'user'):
        cfg.paths[k] = Path(cfg.paths[k])
        cfg.paths[k].mkdir(parents=True, exist_ok=True)
env_cfg = cfg.dataset
env_args = cfg.dataset.env_args

# Point base_dir at the sibling 'ckpt' directory.
cfg.paths.base_dir = cfg.paths.base_dir.parent / 'ckpt'
reference_path = cfg.paths.data_dir/ f"clips/all_clips_batch_interp.p"
# reference_path = cfg.paths.data_dir/ f"clips/{cfg.dataset['clip_idx']}"
reference_path.parent.mkdir(parents=True, exist_ok=True)

# NOTE(review): pickle.load executes arbitrary code from the file; only load
# clips produced by a trusted pipeline.
with open(reference_path, "rb") as file:
    # Use pickle.load() to load the data from the file
    reference_clip = pickle.load(file)
# Stack position, quaternion and joints along the last axis
# (presumably the qpos layout root-pos | root-quat | joints -- TODO confirm).
ref_data = np.concatenate([reference_clip.position,reference_clip.quaternion,reference_clip.joints], axis=-1)

# Load env

In [None]:
from envs.Fly_Env_Brax import FlyTracking, FlyMultiClipTracking, FlyRunSim, _bounded_quat_dist
# dataset = 'multiclip'

# with initialize(version_base=None, config_path="configs"):
#     cfg=compose(config_name='config.yaml',overrides= [f"dataset=fly_{dataset}", f"train=train_fly_{dataset}", "paths=walle"],return_hydra_config=True,)
#     HydraConfig.instance().set_config(cfg)


# env_args = cfg.dataset.env_args
# Register the three project environments under the names used by the configs.
envs.register_environment("fly_freejnt_clip", FlyTracking)
envs.register_environment("fly_freejnt_multiclip", FlyMultiClipTracking)
envs.register_environment("fly_run_policy", FlyRunSim)
# cfg.dataset.env_args.mjcf_path = '/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_viz_only.xml'
# cfg.dataset.env_args.mjcf_path = '/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_force_fast.xml'
print(cfg.train.env_name)
# Override the solver iteration counts stored in the checkpoint config; the
# same values are read back later when standalone MuJoCo models are built.
cfg.dataset.env_args.iterations = 12
cfg.dataset.env_args.ls_iterations = 12
# Instantiate the environment named in the training config, forwarding the
# loaded reference clip plus every entry of env_args as keyword arguments.
env = envs.get_environment(
    cfg.train.env_name,
    reference_clip=reference_clip,
    **cfg.dataset.env_args,
)

In [None]:
# Wrap the environment for rendering rollouts, then jit-compile its reset and
# step functions and create an initial state from a fixed PRNG key.
rollout_env = custom_wrappers.RenderRolloutWrapperTracking(env)
    
# rollout_env = custom_wrappers.RenderRolloutWrapperTracking(env)
# define the jit reset/step functions
jit_reset = jax.jit(rollout_env.reset)
jit_step = jax.jit(rollout_env.step)
state = jit_reset(jax.random.PRNGKey(0))


In [None]:
plt.plot(reference_clip.position[0,:,:])

In [None]:
policy_data = ioh5.load('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/Torque_control_multiclip.h5')


In [None]:
# Render the policy rollout and the reference clip together, frame by frame,
# using the dedicated rendering model.
spec = mujoco.MjSpec()
spec = spec.from_file(cfg.dataset.rendering_mjcf)
mj_model = spec.compile()
print(cfg.dataset.rendering_mjcf)
# The solver is selected through a name->enum table; "cg" is hard-coded here.
mj_model.opt.solver = {
    "cg": mujoco.mjtSolver.mjSOL_CG,
    "newton": mujoco.mjtSolver.mjSOL_NEWTON,
}["cg"]
mj_model.opt.iterations = cfg.dataset.env_args.iterations
mj_model.opt.ls_iterations = cfg.dataset.env_args.ls_iterations
# Use the same timestep as the model held by the brax environment.
mj_model.opt.timestep = env.sys.mj_model.opt.timestep

mj_data = mujoco.MjData(mj_model)
# Ids of all sites whose name contains "-0"; they are recoloured opaque red.
site_id = [
    mj_model.site(i).id
    for i in range(mj_model.nsite)
    if "-0" in mj_model.site(i).name
]
# NOTE(review): the loop variable `id` shadows the Python builtin of that name.
for id in site_id:
    mj_model.site(id).rgba = [1, 0, 0, 1]

scene_option = mujoco.MjvOption()
# Show site groups 0-4 and hide group 5.
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True

# NOTE(review): MUJOCO_GL was set to 'egl' in the setup cell and mujoco has
# already been imported by this point -- confirm this override has any effect.
os.environ["MUJOCO_GL"] = "osmesa"
mujoco.mj_kinematics(mj_model, mj_data)
# renderer = mujoco.Renderer(mj_model, height=512, width=512)

frames = []
clip_idx = 0
qposes_ref = ref_data[clip_idx]
# qposes_rollout = np.concatenate([all_clips_reference.position,all_clips_reference.quaternion,all_clips_reference.joints], axis=-1)[clip_idx]
qposes_rollout = policy_data['qposes'][clip_idx]
# render while stepping using mujoco
with mujoco.Renderer(mj_model, height=480, width=480) as renderer:
    for qpos1, qpos2 in zip(qposes_rollout, qposes_ref):
        # The model's qpos is the rollout pose followed by the reference pose
        # (presumably the rendering model holds two copies of the body).
        mj_data.qpos = np.append(qpos1, qpos2)
        mujoco.mj_forward(mj_model, mj_data)
        renderer.update_scene(mj_data, camera=1, scene_option=scene_option)
        pixels = renderer.render()
        frames.append(pixels)


In [None]:
media.show_video(frames, fps=50)

In [None]:
qposes_rollout[10].copy(), ctrl_all[10]

In [None]:

# Replay recorded controls on a standalone MuJoCo model: every frame resets
# qpos to the recorded pose, applies the recorded ctrl, steps physics until the
# simulation clock reaches the frame time, renders, and logs the state.
import copy  # cell-local import: used to snapshot MjData once per frame

model_path = "/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_force_fastviz.xml"
spec = mujoco.MjSpec()
spec = spec.from_file(model_path)
root = spec.compile()
# Solver is chosen by name from the dataset config (case-insensitive).
root.opt.solver = {
    "cg": mujoco.mjtSolver.mjSOL_CG,
    "newton": mujoco.mjtSolver.mjSOL_NEWTON,
}[cfg.dataset.env_args.solver.lower()]
root.opt.iterations = env_args.iterations
root.opt.ls_iterations = env_args.ls_iterations
root.opt.timestep = env_args.physics_timestep
root.opt.jacobian = 0
data = mujoco.MjData(root)
mujoco.mj_forward(root, data)


n_frames = 1000
height = 512
width = 512
frames = []
# Frame t is reached once data.time >= t / fps, i.e. frames are 0.002 s apart.
fps = 1/.002

times = []
sensordata = []
scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]
clip_idx=1
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True
qpos_all,rollout,ncon_all = [],[],[]
ctrl_all = policy_data['ctrl'][clip_idx]
# NOTE(review): `qposes_rollout` comes from the rendering cell, which selected
# it with its own clip_idx (0 there), whereas `ctrl_all` uses clip_idx=1 here.
# Confirm that mixing poses and controls from different clips is intended.
with mujoco.Renderer(root, height, width) as renderer:
    for t in range(n_frames):
        data.ctrl = ctrl_all[t].copy()
        data.qpos = qposes_rollout[t].copy()
        while data.time < t/fps:
            mujoco.mj_step(root, data)
            # One sensor reading per physics step (not per rendered frame).
            sensordata.append(data.sensordata.copy())
        times.append(data.time)
        renderer.update_scene(data,camera='track2',scene_option=scene_option)
        frame = renderer.render()
        frames.append(frame)
        qpos_all.append(data.qpos.copy())
        ncon_all.append(data.ncon)
        # Store a snapshot, not the live object: `data` is mutated in place on
        # every step, so appending it directly would make all entries of
        # `rollout` alias the final simulation state.
        rollout.append(copy.copy(data))

media.show_video(frames, fps=50)


In [None]:
# Recompute the individual tracking-reward terms frame by frame from the
# simulated states in `rollout` and the reference clip. Each term has the form
# weight * exp(-scaling * squared_distance).
_pos_reward_weight = 1
_joint_reward_weight = 1
_angvel_reward_weight = 1
_bodypos_reward_weight = 1
_endeff_reward_weight = 1
_quat_reward_weight = 1
_pos_scaling = 400.0
_joint_scaling = 0.25
_angvel_scaling = 0.5
_bodypos_scaling = 0.5
_endeff_scaling = 0.05
_quat_scaling = 4.0
clip_idx=1
# Per-frame histories: the bare keys hold squared distances, the '*_reward'
# keys hold the corresponding reward values.
rewards = {'pos':[],'joint':[], 'angvel':[], 'bodypos':[], 'endeff':[],'pos_reward':[], 'joint_reward':[], 'angvel_reward':[], 'bodypos_reward':[], 'endeff_reward':[], 'quat_distance':[], 'quat_reward':[]}
for cur_frame in range (n_frames): 
    # NOTE(review): this rebinds the module-level name `data`. If the entries
    # of `rollout` all alias one MjData object, every frame sees the same state.
    data= rollout[cur_frame]
    
    # Root position term (the variable is named quat_track but holds the
    # reference *position* here; it is reassigned to the quaternion below).
    quat_track = reference_clip.position[clip_idx,cur_frame]
    pos_distance = jp.sum((data.qpos[:3] - quat_track)**2)
    pos_reward = _pos_reward_weight * jp.exp(-_pos_scaling * pos_distance)
    
    # Root orientation term, using the project's bounded quaternion distance.
    quat_track = reference_clip.quaternion[clip_idx,cur_frame]
    quat_distance = jp.sum(_bounded_quat_dist(data.qpos[3:7], quat_track) ** 2)
    quat_reward = _quat_reward_weight * jp.exp(-_quat_scaling * quat_distance)
    
    # Joint-angle term over qpos entries after the 7 root coordinates.
    joint_track = reference_clip.joints[clip_idx,cur_frame]
    joint_distance = jp.sum((data.qpos[7:] - joint_track) ** 2)
    joint_reward = _joint_reward_weight * jp.exp(-_joint_scaling * joint_distance)
    
    # Root angular-velocity term (qvel entries 3..5).
    angvel_track = reference_clip.angular_velocity[clip_idx,cur_frame]
    angvel_distance = jp.sum((data.qvel[3:6] - angvel_track) ** 2)
    angvel_reward = _angvel_reward_weight * jp.exp(-_angvel_scaling * angvel_distance)
    
    # Body-position term restricted to the bodies listed in env._body_idxs.
    bodypos_track = reference_clip.body_positions[clip_idx,cur_frame]
    bodypos_distance = jp.sum((data.xpos[env._body_idxs]- bodypos_track[env._body_idxs]).flatten()** 2)
    bodypos_reward = _bodypos_reward_weight * jp.exp(-_bodypos_scaling* bodypos_distance)
    
    # End-effector term: same reference array, restricted to env._endeff_idxs.
    endeff_track = reference_clip.body_positions[clip_idx,cur_frame]
    endeff_distance = jp.sum((data.xpos[env._endeff_idxs]- endeff_track[env._endeff_idxs]).flatten()** 2)
    endeff_reward = _endeff_reward_weight * jp.exp(-_endeff_scaling* endeff_distance)

        
    rewards['pos'].append(pos_distance)
    rewards['joint'].append(joint_distance)
    rewards['angvel'].append(angvel_distance)
    rewards['bodypos'].append(bodypos_distance)
    rewards['endeff'].append(endeff_distance)
    rewards['quat_distance'].append(quat_distance)
    
    rewards['pos_reward'].append(pos_reward)
    rewards['joint_reward'].append(joint_reward)
    rewards['angvel_reward'].append(angvel_reward)
    rewards['bodypos_reward'].append(bodypos_reward)
    rewards['endeff_reward'].append(endeff_reward)
    rewards['quat_reward'].append(quat_reward)
    


In [None]:

# Two stacked panels: squared distances on top, reward values below.
fig, axs = plt.subplots(2,1, figsize=(10,10))
ax = axs[0]
# ax.plot(rewards['pos'], label='pos')
# ax.plot(rewards['joint'], label='joint')
# ax.plot(rewards['angvel'], label='angvel')
# ax.plot(rewards['bodypos'], label='bodypos')
ax.plot(rewards['endeff'], label='endeff')
ax.plot(rewards['quat_distance'], label='quat_distance')
ax.legend()

ax = axs[1]
ax.plot(rewards['pos_reward'], label='pos_reward')
ax.plot(rewards['joint_reward'], label='joint_reward')
# ax.plot(rewards['angvel_reward'], label='angvel_reward')
# ax.plot(rewards['bodypos_reward'], label='bodypos_reward')
ax.plot(rewards['endeff_reward'], label='endeff_reward')
ax.plot(rewards['quat_reward'], label='quat_reward')
ax.legend()
# Every reward weight above is 1, so each term lies in [0, 1]; pad the axis.
ax.set_ylim(-.1,1.1)



In [None]:
joint_distance

In [None]:
plt.plot(ncon_all)

In [None]:
# Plot the smoothed per-leg sensor readings collected during the replay.
end_eff = [
'claw_T1_left',
'claw_T1_right',
'claw_T2_left',
'claw_T2_right',
'claw_T3_left',
'claw_T3_right',
]
# Width (in samples) of the moving-average window.
N = 10
sdata = (np.stack(sensordata).reshape(-1,6,3)) # Time x end_eff x xyz, x=forward
# Box-filter each channel along time and scale by 10. mode='full' makes the
# output N-1 samples longer than the input.
sdata = 10* np.apply_along_axis(lambda m: np.convolve(m, np.ones(N)/N, mode='full'), axis=0, arr=sdata)

# One panel per entry of end_eff, three traces (last-axis components) each.
fig, axs = plt.subplots(3, 2, figsize=(10, 10), sharey=True)
axs = axs.flatten()
for n in range(len(end_eff)):
    ax = axs[n]
    ax.plot(sdata[:,n,0])
    ax.plot(sdata[:,n,1])
    ax.plot(sdata[:,n,2])
# plt.plot(sdata[:,:,2])

# Cleaning Data

In [None]:
%matplotlib widget

In [None]:

# Build the training model and read the default-pose positions of the leg
# tracking sites; these serve as the alignment target (`ref_bout`).
spec = mujoco.MjSpec()
spec = spec.from_file(cfg.dataset.env_args.mjcf_path)
mj_model = spec.compile()

mj_model.opt.solver = {
"cg": mujoco.mjtSolver.mjSOL_CG,
"newton": mujoco.mjtSolver.mjSOL_NEWTON,
}["cg"]
mj_model.opt.iterations = cfg.dataset.env_args.iterations
mj_model.opt.ls_iterations = cfg.dataset.env_args.ls_iterations
mj_model.opt.timestep = cfg.dataset.env_args.physics_timestep

mj_data = mujoco.MjData(mj_model)

# 6 legs x 5 sites (coxa, femur, tibia, tarsus, claw), listed leg by leg.
site_names = [
'tracking[coxa_T1_left]',
'tracking[femur_T1_left]',
'tracking[tibia_T1_left]',
'tracking[tarsus_T1_left]',
'tracking[claw_T1_left]',
'tracking[coxa_T1_right]',
'tracking[femur_T1_right]',
'tracking[tibia_T1_right]',
'tracking[tarsus_T1_right]',
'tracking[claw_T1_right]',
'tracking[coxa_T2_left]',
'tracking[femur_T2_left]',
'tracking[tibia_T2_left]',
'tracking[tarsus_T2_left]',
'tracking[claw_T2_left]',
'tracking[coxa_T2_right]',
'tracking[femur_T2_right]',
'tracking[tibia_T2_right]',
'tracking[tarsus_T2_right]',
'tracking[claw_T2_right]',
'tracking[coxa_T3_left]',
'tracking[femur_T3_left]',
'tracking[tibia_T3_left]',
'tracking[tarsus_T3_left]',
'tracking[claw_T3_left]',
'tracking[coxa_T3_right]',
'tracking[femur_T3_right]',
'tracking[tibia_T3_right]',
'tracking[tarsus_T3_right]',
'tracking[claw_T3_right]',
]

# Resolve each site name to its integer id in the compiled model.
site_idxs = [
    mujoco.mj_name2id(mj_model, mujoco.mju_str2Type("site"), site)
    for site in site_names
]
# Forward pass on the freshly created data populates site_xpos.
mujoco.mj_forward(mj_model, mj_data)
ref_bout = mj_data.site_xpos[site_idxs]
# Express all sites relative to the first one (the T1-left coxa).
ref_bout = ref_bout - ref_bout[0]

    
def procrustes_jax(source_points, target_points, floor=None):
    """
    Estimate the rotation and scale that align source_points to target_points.

    Both point sets are centred and scaled to unit Frobenius norm, and the
    orthogonal matrix R minimising ``||source @ R - target||`` on those
    normalised sets is obtained from an SVD. No reflection check is applied,
    so R may have determinant -1.

    Parameters:
    source_points : (N, M) array_like
        The points to be aligned.
    target_points : (N, M) array_like
        The reference/target points.
    floor : float, optional
        Unused; kept so existing call sites keep working.

    Returns:
    d : float
        Residual sum of squared errors between the normalised, centred target
        and the rotated normalised, centred source.
    Z : (N, M) ndarray
        ``source_points @ R * scale``. The input is rotated and scaled but NOT
        re-centred or translated onto the target.
    R : (M, M) ndarray
        Orthogonal matrix, applied on the right of row-vector points.
    scale : float
        Ratio of the target to the source Frobenius norm (after centring).
    """
    del floor  # accepted for backward compatibility only

    # Center both matrices
    source_centered = source_points - jp.mean(source_points, axis=0)
    target_centered = target_points - jp.mean(target_points, axis=0)

    # Normalize the Frobenius norm to 1
    target_norm = jp.linalg.norm(target_centered)
    source_norm = jp.linalg.norm(source_centered)
    target_centered = target_centered / target_norm
    source_centered = source_centered / source_norm

    # Orthogonal Procrustes: with M = source^T target = U S Vt, R = U Vt.
    U, _, Vt = jp.linalg.svd(jp.dot(source_centered.T, target_centered))
    R = jp.dot(U, Vt)

    scale = target_norm / source_norm
    Z = jp.dot(source_points, R) * scale

    # Measure the residual in the same (centred, unit-norm) space for both
    # point sets; comparing against the un-centred, un-normalised Z would mix
    # two coordinate frames and make d meaningless.
    d = jp.sum(jp.square(target_centered - jp.dot(source_centered, R)))

    return d, Z, R, scale

def align_groundplane(Z0, percentile=10, floor_height=-0.13):
    """
    Rotate and translate keypoints so that their lowest samples lie on the floor.

    For every keypoint, the samples whose height (last coordinate) falls below
    the given percentile of that keypoint's heights are collected. A rigid
    transform is then fitted (Kabsch) that maps those samples onto copies of
    themselves with the height replaced by ``floor_height``.

    Args:
        Z0: (T, K, 3) array of keypoint positions.
        percentile: percentile of each keypoint's heights, computed over the
            first axis, used as the "touching the ground" threshold.
        floor_height: height assigned to the ground-contact targets.

    Returns:
        Z2: (T, K, 3) array, ``Z0 @ R2.T + t``.
        R2: (3, 3) proper rotation (det = +1) acting on column vectors, i.e.
            apply it to row-vector points as ``points @ R2.T``.
        t: (3,) translation added after the rotation.
        bot_t: (n, 3) samples selected as ground contacts.
        groundplane: (n, 3) targets, ``bot_t`` with height ``floor_height``.

    Note:
        Uses boolean-mask indexing and a Python ``if`` on an array value, so it
        must be called eagerly (not under ``jax.jit``). If no sample falls
        below the threshold the means are taken over an empty set.
    """
    Z0 = jp.asarray(Z0)

    # Select, per keypoint, the samples lying below the height percentile.
    Z1 = Z0[:, :, 2]
    bot_10 = jp.percentile(Z1, percentile, axis=0)
    bot_t = Z0[Z1 < bot_10]
    groundplane = bot_t.at[:, 2].set(floor_height)

    source_points = bot_t
    target_points = groundplane
    source_mean = jp.mean(source_points, axis=0)
    target_mean = jp.mean(target_points, axis=0)
    source_centered = source_points - source_mean
    target_centered = target_points - target_mean

    # Kabsch: with H^T = target^T source = U S Vt, the rotation R2 = U Vt maps
    # source column vectors onto target column vectors.
    H = jp.dot(source_centered.T, target_centered)
    U, _, Vt = jp.linalg.svd(H.T)
    R2 = jp.dot(U, Vt)

    # Ensure a proper rotation (det(R) = 1, avoiding reflection)
    if jp.linalg.det(R2) < 0:
        Vt = Vt.at[-1, :].set(-1 * Vt[-1, :])
        R2 = jp.dot(U, Vt)
    t = target_mean - jp.dot(R2, source_mean)

    # R2 acts on column vectors, so row-vector points are multiplied by R2.T.
    # (Multiplying by R2 itself would apply the inverse rotation.)
    Z2 = jp.dot(Z0, R2.T) + t
    return Z2, R2, t, bot_t, groundplane


def plot_plane_and_points(points, fig=None, ax=None, color='r', label=''):
    """
    Scatter 3D points and draw the least-squares plane fitted through them.

    Args:
        points (jax.numpy.ndarray): Array of shape (N, 3), the 3D points.
        fig: Optional matplotlib figure. Only used to create the axes when
            ``ax`` is not given; a new figure is created if both are None.
        ax: Optional 3D axes to draw on.
        color: Colour used for both the points and the plane.
        label: Prefix for the legend entries (' Points' / ' Plane').

    Note:
        The plane is drawn as z(x, y), so a fitted plane whose normal has a
        zero z-component (a vertical plane) cannot be represented.
    """
    print('points',points.shape)

    # Compute the centroid of the points
    centroid = jp.mean(points, axis=0)

    # Center the points around the centroid
    centered_points = points - centroid

    # Compute the SVD of the centered points
    _, _, vh = jp.linalg.svd(centered_points, full_matrices=False)

    # The plane's normal is the singular vector corresponding to the smallest singular value
    normal = vh[-1]

    # Compute the intercept term (d) for the plane equation ax + by + cz + d = 0
    d = -jp.dot(normal, centroid)

    # Create a mesh grid for the plane spanning the xy extent of the points
    x = jp.linspace(points[:, 0].min(), points[:, 0].max(), 10)
    y = jp.linspace(points[:, 1].min(), points[:, 1].max(), 10)
    x, y = jp.meshgrid(x, y)

    # Compute the corresponding z values for the plane equation ax + by + cz + d = 0
    z = (-normal[0] * x - normal[1] * y - d) / normal[2]

    # Create whatever is missing. Previously axes were only created when BOTH
    # fig and ax were None, so passing a figure alone left ax as None.
    if ax is None:
        if fig is None:
            fig = plt.figure(figsize=(10, 7))
        ax = fig.add_subplot(111, projection='3d')

    # Plot the points
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=color, label=label+' Points')

    # Plot the plane
    ax.plot_surface(x, y, z, alpha=0.5, color=color, label=label + ' Plane')
    # ax.plot_surface(x, y, -.13*jp.ones_like(x), alpha=0.5, color='black', label='Ground Plane')

    # Set labels and title
    ax.set_xlabel("X-axis")
    ax.set_ylabel("Y-axis")
    ax.set_zlabel("Z-axis")
    ax.set_title("Fitted Plane and 3D Points")
    ax.legend()
    

# procrustes_partial = functools.partial(procrustes_jax)
# Compiled version of procrustes_jax, plus a batched variant that keeps the
# first argument fixed and maps over axis 0 of the second argument.
jit_procrustes = jax.jit(procrustes_jax)
vmap_procrustes_partial = jax.vmap(jit_procrustes, (None,0))


In [None]:
import pandas as pd
from scipy.spatial import procrustes
# Load the walking dataset and derive per-bout statistics and column names.
base_path = Path('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/')
data_path = base_path / 'combined_wt_berlin_walking_v3.pq'
full_df = pd.read_parquet(data_path, engine='pyarrow')
# Per (bout, file, sex) summary statistics of the two fictrac columns.
bout_stats = full_df.groupby(['walking_bout_number','fullfile','Sex'])[['fictrac_delta_rot_lab_y_mms', 'fictrac_delta_rot_lab_z_deg/s']].agg(['mean','min','max','std','count'])
# "fast": mean of the y column >= 12 and its minimum >= 10
# (units presumably mm/s, going by the column name).
fast = (bout_stats[('fictrac_delta_rot_lab_y_mms','mean')] >= 12) & (bout_stats[('fictrac_delta_rot_lab_y_mms','min')] >= 10)
# "straight": |mean| of the z column <= 45 and all samples within [-60, 60]
# (units presumably deg/s, going by the column name).
straight = (bout_stats[('fictrac_delta_rot_lab_z_deg/s','mean')].abs() <= 45) &\
           (bout_stats[('fictrac_delta_rot_lab_z_deg/s','min')] >= -60) &\
           (bout_stats[('fictrac_delta_rot_lab_z_deg/s','max')] <= 60)

legs = ['T1_left', 'T1_right', 'T2_left', 'T2_right', 'T3_left', 'T3_right']
joints = ['coxa', 'femur', 'tibia', 'tarsus']
xpos_geoms = ['coxa', 'femur', 'tibia', 'tarsus', 'claw']
# Names are generated leg-major: all segments of a leg, then the next leg.
joint_names = [f'{joint}_{leg}' for leg in legs for joint in joints]
xpos_names = [f'{joint}_{leg}' for leg in legs for joint in xpos_geoms]
# physics.named.data.framepos[pos_names]
site_names = [f'tracking[{joint_name}]' for joint_name in xpos_names]

# Dataset column naming: <leg><point><coord>, e.g. 'L1A_x'. The order
# (leg, point, coord) produces 6 * 5 * 3 = 90 columns.
legs_data = ['L1', 'R1', 'L2','R2', 'L3','R3']
joints_data = ['A','B','C','D','E']
coords_data = ['_x','_y','_z']
joint_pos_columns = [leg + joint + coord 
                     for leg in legs_data
                     for joint in joints_data 
                     for coord in coords_data]

def transform_bout(bout):
    """Convert a bout of keypoints from the data frame to the model frame.

    ``bout`` is indexed with three axes and the last axis holds (x, y, z).
    The result is a new array with (x, y, z) -> (y, -x, z), multiplied by 0.1
    (mm to cm). The input array is left untouched.
    """
    # Indexing with a list builds a fresh array, so the in-place edits below
    # never reach the caller's data.
    remapped = bout[:, :, [1, 0, 2]]
    # Negate the new second coordinate to complete the rotation about z.
    remapped[:, :, 1] *= -1
    # Change units mm to cm.
    remapped *= 0.1
    return remapped
# Indices of the first (coxa) and last (claw) keypoint of each leg in the
# 30-keypoint layout (5 keypoints per leg).
coxa_kp_idxs = np.arange(0,30,5)
endeff_kp_idxs = np.arange(4,30,5)
straight_bouts = bout_stats[fast & straight].index
straight_bout_num = np.array([i[0] for i in straight_bouts],dtype=int)
# NOTE(review): bout 14574 is hard-coded and `mean_straight_bout` is not used
# anywhere in the visible code.
mean_straight_bout = (full_df[full_df['walking_bout_number'] == 14574][joint_pos_columns].values.reshape(-1,30, 3)).mean(axis=0)
# ref_bout = transform_bout(full_df[full_df['walking_bout_number'] == 14574][joint_pos_columns].values.reshape(-1,30, 3))
all_bout_nums = full_df['walking_bout_number'].unique()
# Output keys use the enumeration index, not the original bout number.
bout_dict = {'walking_bout{:04}'.format(n):{} for n in range(len(all_bout_nums))}
for n, bout_num in enumerate(tqdm(all_bout_nums)):
    bout = full_df[full_df['walking_bout_number'] == bout_num]
    # (frames, 30 keypoints, xyz), converted to the model frame and units.
    kp_data = transform_bout(bout[joint_pos_columns].values.reshape(-1,30, 3))
    # kp_data = np.array([procrustes(ref_bout, kp_data[i])[1] for i in range(kp_data.shape[0])])
    # kp_data = np.array(vmap_procrustes_partial(ref_bout,kp_data))
    # Express every frame relative to its first keypoint.
    kp_data = kp_data - kp_data[:,:1,:]
    # Fit rotation/scale from the bout's mean pose to the model's reference
    # pose, then apply that single transform to every frame.
    d, _, R, scale  = procrustes_jax(jp.mean(kp_data,axis=0),ref_bout[:])
    Z = jp.dot(kp_data, R) * scale
    # Ground alignment is estimated from the claw keypoints only; the Z2
    # returned by align_groundplane is discarded and the fitted transform is
    # applied to all keypoints instead.
    Z0 = Z[:,endeff_kp_idxs,:]
    _, R2, t, bot_t, Ground = align_groundplane(Z0,percentile=5)
    Z2 = jp.dot(Z, R2.T) + t
    bout_dict['walking_bout{:04}'.format(n)]['orig_xpos'] = np.array(Z2).copy()
    # bout_dict['walking_bout{:04}'.format(n)]['orig_xpos'] = np.array(kp_data).copy()
# ioh5.save('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/all_walking_bouts.h5', bout_dict)

In [None]:
# coxa_kp_idxs = np.arange(0,30,5)
# endeff_kp_idxs = np.arange(4,30,5)
# n = 1729
# kp_data = bout_dict['walking_bout{:04}'.format(n)]['orig_xpos'].copy()
# d, _, R, scale  = procrustes_jax(jp.mean(kp_data,axis=0),ref_bout[:])
# Z = jp.dot(kp_data, R.T) * scale
# Z0 = Z[:,endeff_kp_idxs,:]
# _, R2, t, bot_t, Ground = align_groundplane(Z0,percentile=10)
# Z2 = jp.dot(Z, R2.T) + t

In [None]:
plt.close('all')

In [None]:
import matplotlib as mpl
def map_discrete_cbar(cmap,N):
    """Return a colormap resampled to N+1 colours and a matching BoundaryNorm.

    The bin edges are -0.5, 0.5, ..., N + 0.5, so the integers 0..N each fall
    in the middle of their own bin.
    """
    discrete_cmap = plt.get_cmap(cmap, N + 1)
    bin_edges = np.arange(-.5, N + 1)
    boundary_norm = mpl.colors.BoundaryNorm(bin_edges, discrete_cmap.N)
    return discrete_cmap, boundary_norm
cmap,norm = map_discrete_cbar('turbo',6)
# Skeleton edges: consecutive keypoints (n, n+1), grouped per leg so that no
# edge links the claw of one leg to the coxa of the next. Result is (6, 4, 2).
fly_skel = np.array([(n,n+1) for n in range(29)])
fly_skel = np.stack((fly_skel[:4],fly_skel[5:9],fly_skel[10:14],fly_skel[15:19],fly_skel[20:24],fly_skel[25:30],))

# Uses kp_data / Z left over from the last iteration of the alignment loop.
source_points = kp_data[0,coxa_kp_idxs]
target_points = ref_bout[coxa_kp_idxs]

fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
plot_plane_and_points(source_points.reshape(-1,3), fig, ax, color='b', label='Source')
plot_plane_and_points(target_points.reshape(-1,3), fig, ax, color='k', label='Y0')
plot_plane_and_points(Z[:,coxa_kp_idxs].reshape(-1,3), fig, ax, color='r', label='Z')

# ax.scatter(ref_bout[::5, 0], ref_bout[::5, 1],ref_bout[::5, 2], c=np.arange(ref_bout[::5].shape[0]),cmap=cmap)
# ax.scatter(Z[::5, 0], Z[::5, 1],Z[::5, 2], c=np.arange(Z[::5].shape[0]),cmap=cmap)
# ax.scatter(Z[:,4::5, 0], Z[:,4::5, 1],Z[:,4::5, 2],c='r')
# Draw each leg: reference pose in black, time-averaged aligned pose in red.
for n in range(fly_skel.shape[0]):
    ax.plot(ref_bout[fly_skel[n], 0], ref_bout[fly_skel[n], 1],  ref_bout[fly_skel[n], 2],c='k')
    ax.plot(np.mean(Z[:,fly_skel[n], 0],axis=0), np.mean(Z[:,fly_skel[n], 1],axis=0),np.mean(Z[:,fly_skel[n], 2],axis=0),c='r')

In [None]:
# Visual check of the ground-plane alignment: the selected ground-contact
# samples after the fitted transform, the target plane, and the aligned claws.
Z2b = jp.dot(bot_t, R2.T) + t
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
# plot_plane_and_points(bot_t.reshape(-1,3), fig, ax, color='b', label='Z')
plot_plane_and_points(Ground.reshape(-1,3), fig, ax, color='k', label='ground')
plot_plane_and_points(Z2b[:,:].reshape(-1,3), fig, ax, color='r', label='Z2')
plot_plane_and_points(Z2[:,endeff_kp_idxs].reshape(-1,3), fig, ax, color='r', label='Z2')

# Draw each leg: reference pose in black, time-averaged ground-aligned pose in
# red. All three coordinates come from Z2; the z coordinate was previously read
# from Z (the pose before ground alignment), which mixed two frames.
for n in range(fly_skel.shape[0]):
    ax.plot(ref_bout[fly_skel[n], 0], ref_bout[fly_skel[n], 1],  ref_bout[fly_skel[n], 2],c='k')
    ax.plot(np.mean(Z2[:,fly_skel[n], 0],axis=0), np.mean(Z2[:,fly_skel[n], 1],axis=0),np.mean(Z2[:,fly_skel[n], 2],axis=0),c='r')

In [None]:
# bout_dict2 = ioh5.load('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/all_walking_bouts.h5')
ioh5.save('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/all_walking_bouts.h5', bout_dict)

In [None]:
stac_path = Path('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/transform_mocap_fly_all.p')
# Load mocap data from a file.
# with open(stac_path, "rb") as file:
#     d = pickle.load(file)
#     qposes = np.array(d["qpos"])
#     kp_data = np.array(d["kp_data"])
#     kp_names = d["kp_names"]
#     offsets = d["offsets"]
# Only the "qpos" entry of the pickle is kept.
# NOTE(review): pickle.load executes arbitrary code from the file; only load
# files produced by a trusted pipeline. `d` also overwrites the earlier `d`.
with open(stac_path, "rb") as file:
    d = pickle.load(file)
    mocap_qpos = np.array(d["qpos"])

In [None]:
from preprocessing.mjx_preprocess import process_clip, save_reference_clip_to_h5, load_reference_clip_from_h5, ReferenceClip
from jax import vmap
import mujoco.mjx as mjx

In [None]:
[col for col in bout.columns if 'fictrac' in col]

In [None]:
# Gather per-bout velocities / heading and cut the concatenated qpos array
# into one chunk per bout, using the bout lengths as chunk sizes.
lin_vel_y_cm, lin_vel_x_cm, heading_deg = [], [], []
for n, bout_num in enumerate(all_bout_nums):
    bout = full_df[full_df['walking_bout_number'] == bout_num]
    # Divide by 10 to go from the dataset's mm-based columns to cm.
    lin_vel_y_cm.append(bout['fictrac_delta_rot_lab_y_mms'].values/10)
    lin_vel_x_cm.append(bout['fictrac_delta_rot_lab_x_mms'].values/10)
    heading_deg.append(bout['fictrac_heading_deg'].values)
# lin_vel_y_cm = np.concatenate(lin_vel_y_cm,axis=0)
clip_shape = np.array([clip.shape[0] for clip in lin_vel_y_cm])
clip_shape = np.concatenate([[0],clip_shape])
# Running totals of the bout lengths give the chunk boundaries directly.
clip_edges = np.cumsum(clip_shape)
mocap_qpos_reshaped = [
    mocap_qpos[start:stop]
    for start, stop in zip(clip_edges[:-2], clip_edges[1:-1])
]
# The last chunk takes whatever remains, which may be shorter than the bout.
mocap_qpos_reshaped.append(mocap_qpos[clip_edges[-2]:])
# Trim the last bout's velocities to the length of the last qpos chunk.
last_len = mocap_qpos_reshaped[-1].shape[0]
lin_vel_y_cm[-1] = lin_vel_y_cm[-1][:last_len]
lin_vel_x_cm[-1] = lin_vel_x_cm[-1][:last_len]

In [None]:
%matplotlib inline

In [None]:
fig, axs = plt.subplots(figsize=(3,3))
ax = axs
ax.plot(heading_deg[2])

In [None]:
# Overwrite the first two qpos columns of every bout with the running sum of
# the measured velocities times 1/300 (the mocap frame period used elsewhere
# in this notebook). Column 0 gets the y velocity, column 1 the x velocity.
# The chunks are slices of `mocap_qpos`, so this writes into that array too.
for n in tqdm(range(len(mocap_qpos_reshaped))):
    mocap_qpos_reshaped[n][:,0] = np.cumsum(lin_vel_y_cm[n],axis=0)*(1/300)
    mocap_qpos_reshaped[n][:,1] = np.cumsum(lin_vel_x_cm[n],axis=0)*(1/300)

In [None]:
clip_idx = 0
for clip_idx in range(10):
    plt.plot(mocap_qpos_reshaped[clip_idx][:,0],mocap_qpos_reshaped[clip_idx][:,1])
    plt.scatter(mocap_qpos_reshaped[clip_idx][0,0],mocap_qpos_reshaped[clip_idx][0,1],c='g')
    plt.scatter(mocap_qpos_reshaped[clip_idx][-1,0],mocap_qpos_reshaped[clip_idx][-1,1],c='r')
# plt.axis([-1,1,-1,1]) 

In [None]:
# Linearly resample every bout to exactly 1000 samples over its own duration.
mocap_qpos_interp = []
# These two lists stay empty: the lines that would fill them are commented out.
lin_vel_y_interp = []
lin_vel_x_interp = []
orig_t = []
for n in tqdm(range(len(mocap_qpos_reshaped))):
    clip_len = mocap_qpos_reshaped[n].shape[0]
    tmax = 1/300 * clip_len # 1/300 is original mocap hz
    t = jp.linspace(0,tmax,clip_len)
    t_interp = jp.linspace(0,tmax,1000)
    # Interpolate each qpos column independently along the time axis.
    mocap_qpos_interp.append(jp.apply_along_axis(lambda fp,x,xp: jp.interp(x,xp,fp), 0, mocap_qpos_reshaped[n],x=t_interp,xp=t))
    # lin_vel_y_interp.append(jp.apply_along_axis(lambda fp,x,xp: jp.interp(x,xp,fp), 0, lin_vel_y_cm[n],x=t_interp,xp=t))
    # lin_vel_x_interp.append(jp.apply_along_axis(lambda fp,x,xp: jp.interp(x,xp,fp), 0, lin_vel_x_cm[n],x=t_interp,xp=t))
    orig_t.append(t)
    # print(key, ref_clip_interp[key][n].shape)
    
# (n_bouts, 1000, n_qpos) array.
mocap_qpos_interp = np.stack(mocap_qpos_interp,axis=0)
# lin_vel_x_interp = np.stack(lin_vel_x_interp,axis=0)
# lin_vel_y_interp = np.stack(lin_vel_y_interp,axis=0)

In [None]:
mocap_qpos_interp.shape,lin_vel_x_interp.shape

In [None]:
# Adds the integrated interpolated velocities to the first two qpos columns.
# NOTE(review): in the resampling cell the code that fills lin_vel_x_interp /
# lin_vel_y_interp is commented out, so both are still empty lists there and
# this cell cannot run as-is. The positions were also already integrated
# before resampling -- confirm this cell is still wanted.
mocap_qpos_interp[:,:,1] += np.cumsum(lin_vel_x_interp,axis=1)*cfg.dataset.env_args.physics_timestep
mocap_qpos_interp[:,:,0] += np.cumsum(lin_vel_y_interp,axis=1)*cfg.dataset.env_args.physics_timestep


In [None]:
clip_idx = 0
for clip_idx in range(100):
    plt.plot(mocap_qpos_interp[clip_idx,:,0],mocap_qpos_interp[clip_idx,:,1])
    plt.scatter(mocap_qpos_interp[clip_idx,0,0],mocap_qpos_interp[clip_idx,0,1],c='g')
    plt.scatter(mocap_qpos_interp[clip_idx,-1,0],mocap_qpos_interp[clip_idx,-1,1],c='r')
# plt.axis([-1,1,-1,1])

In [None]:
# Matches the 1000 samples every bout was resampled to.
clip_length = 1000
start_step = 0

# jit the process_clip function 
jit_process_clip = jax.jit(process_clip)


# Reshape_qposes to have the batch dimension and vmap the jitted function
all_clips_qpos_interp = jp.reshape(mocap_qpos_interp, (-1, clip_length, mocap_qpos_interp.shape[-1]))
# Map over axis 0 of the first argument only; the other two are shared.
vmap_jit_process_clip = jax.vmap(jit_process_clip, in_axes=(0,None,None))
all_clips_qpos_interp.shape

In [None]:

# Rebuild the training model, move model and data to MJX, and run the batched
# clip processing over all resampled bouts.
spec = mujoco.MjSpec()
spec = spec.from_file(cfg.dataset.env_args.mjcf_path)
mj_model = spec.compile()

mj_model.opt.solver = {
"cg": mujoco.mjtSolver.mjSOL_CG,
"newton": mujoco.mjtSolver.mjSOL_NEWTON,
}["cg"]
mj_model.opt.iterations = cfg.dataset.env_args.iterations
mj_model.opt.ls_iterations = cfg.dataset.env_args.ls_iterations
mj_model.opt.timestep = cfg.dataset.env_args.physics_timestep

mj_data = mujoco.MjData(mj_model)

# Initialize MuJoCo model and data structures & place into GPU
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

# The model and data are shared across clips; only the qpos batch is mapped.
all_clips_interp = vmap_jit_process_clip(all_clips_qpos_interp, mjx_model, mjx_data)

In [None]:
# #### Process clips #####

# spec = mujoco.MjSpec()
# spec = spec.from_file(cfg.dataset.env_args.mjcf_path)
# mj_model = spec.compile()

# mj_model.opt.solver = {
# "cg": mujoco.mjtSolver.mjSOL_CG,
# "newton": mujoco.mjtSolver.mjSOL_NEWTON,
# }["cg"]
# mj_model.opt.iterations = cfg.dataset.env_args.iterations
# mj_model.opt.ls_iterations = cfg.dataset.env_args.ls_iterations
# mj_model.opt.timestep = cfg.dataset.env_args.physics_timestep

# mj_data = mujoco.MjData(mj_model)

# # Initialize MuJoCo model and data structures & place into GPU
# mjx_model = mjx.put_model(mj_model)
# mjx_data = mjx.put_data(mj_model, mj_data)

# all_clips = []
# for n in range(len(mocap_qpos_reshaped)):
#     clip=process_clip(mocap_qpos_reshaped[n],mjx_model,mjx_data,max_qvel=20, dt=1/300)
#     all_clips.append(clip)

In [None]:
mocap_qpos_reshaped[0].shape

In [None]:
all_clips_reference2 = {f'clip{n:04}':clip for n, clip in enumerate(all_clips)}

In [None]:
# clip_names = [f'clip{n:04}' for n in range(len(mocap_qpos_reshaped))]
# all_clips = load_reference_clip_from_h5('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/clips/all_clips_raw2.h5',clip_names)
# save_reference_clip_to_h5('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/clips/all_clips_raw.h5',clip_names=clip_names,reference_clip=all_clips_reference2)

In [None]:
# Copy the fields of the processed clips into a plain dict keyed by field name.
# NOTE(review): `all_clips` is only assigned in commented-out code within the
# visible notebook (the batched result is named `all_clips_interp`); confirm
# which object is meant here.
ref_clip = {}
ref_clip['position'] = all_clips.position
ref_clip['quaternion'] = all_clips.quaternion
ref_clip['joints'] = all_clips.joints
ref_clip['body_positions'] = all_clips.body_positions
ref_clip['velocity'] = all_clips.velocity
ref_clip['joints_velocity'] = all_clips.joints_velocity
ref_clip['angular_velocity'] = all_clips.angular_velocity
ref_clip['body_quaternions'] = all_clips.body_quaternions

In [None]:
# reference_path = cfg.paths.data_dir/ f"clips/all_clips_batch_interp.p"
# # reference_path = cfg.paths.data_dir/ f"clips/{cfg.dataset['clip_idx']}"
# reference_path.parent.mkdir(parents=True, exist_ok=True)

# with open(reference_path, "rb") as file:
#     # Use pickle.load() to load the data from the file
#     all_clips = pickle.load(file)

In [None]:
n = 0
jp.zeros((all_clips.body_positions[n].shape)).shape

In [None]:
# body_positions = ref_clip['body_positions']
# bodypos_all = []
# for n in range(len(all_clips.body_positions)):
#     dpos = jp.zeros((all_clips.body_positions[n].shape[0],3))
#     dpos = dpos.at[:,0].set(jp.cumsum(lin_vel_y_cm[n])*env.dt)
#     dpos = dpos.at[:,1].set(jp.cumsum(lin_vel_x_cm[n])*env.dt)
    
#     bodypos_all.append(all_clips.body_positions[n]+dpos[:,None,:])

# for n in range(len(ref_clip['body_positions'])):
#     ref_clip['body_positions'][n] = bodypos_all[n]

In [None]:
# Add the integrated measured velocities to each clip's root position.
# NOTE(review): this rebinds `qpos_all`, which the replay cell also uses.
qpos_all = []
for n in range(len(all_clips.position)):
    dpos = jp.zeros((all_clips.position[n].shape[0],3))
    # NOTE(review): the integration step here is env.dt, while the earlier
    # qpos integration used 1/300 -- confirm which time base is correct.
    dpos = dpos.at[:,0].set(jp.cumsum(lin_vel_y_cm[n])*env.dt)
    dpos = dpos.at[:,1].set(jp.cumsum(lin_vel_x_cm[n])*env.dt)
    
    qpos_all.append(all_clips.position[n]+dpos)

# Item assignment requires ref_clip['position'] to be a mutable sequence
# (e.g. a list of arrays) -- TODO confirm its type.
for n in range(len(ref_clip['position'])):
    ref_clip['position'][n] = qpos_all[n]

### Interpolate data

In [None]:
# One-off check of the interpolation on a single field of a single clip.
n = 0
key = 'position'
ref_clip_interp = {key:[] for key in ref_clip.keys()}

clip_len = ref_clip[key][n].shape[0]
tmax = 1/300 * clip_len # 1/300 is original mocap hz
t = jp.linspace(0,tmax,clip_len)
t_interp = jp.linspace(0,tmax,1000)
# Interpolate every column independently along the time axis.
ref_clip_interp[key].append(jp.apply_along_axis(lambda fp,x,xp: jp.interp(x,xp,fp), 0, ref_clip[key][n],x=t_interp,xp=t))
print(key, ref_clip_interp[key][n].shape)

In [None]:
# Resample every field of every clip to 1000 samples over the clip's duration.
ref_clip_interp = {key:[] for key in ref_clip.keys()}
for key,val in ref_clip.items():
    for n in range(len(val)):
        clip_len = ref_clip[key][n].shape[0]
        tmax = 1/300 * clip_len # 1/300 is original mocap hz
        t = jp.linspace(0,tmax,clip_len)
        t_interp = jp.linspace(0,tmax,1000)
        # Linear interpolation applied along axis 0 (time) of the field.
        ref_clip_interp[key].append(jp.apply_along_axis(lambda fp,x,xp: jp.interp(x,xp,fp), 0, ref_clip[key][n],x=t_interp,xp=t))
        print(key, ref_clip_interp[key][n].shape)

In [None]:
def random_limited_quaternion(random, limit):
    """Return a random rotation quaternion whose angle is drawn from [0, limit).

    Args:
        random: Random source exposing `randn` and `rand` (e.g. `np.random`
            or an `np.random.RandomState`).
        limit: Upper bound on the rotation angle.

    Returns:
        A length-4 NumPy array filled in by `mujoco.mju_axisAngle2Quat`.
    """
    # Random direction: a Gaussian sample scaled to unit length.
    direction = random.randn(3)
    unit_axis = direction / np.linalg.norm(direction)

    # Rotation angle, uniform in [0, limit).
    rotation_angle = random.rand() * limit

    quat_out = np.zeros(4)
    mujoco.mju_axisAngle2Quat(quat_out, unit_axis, rotation_angle)
    return quat_out

def randomize_limited_and_rotational_joints(mj_model,mj_data, random=None):
    """Randomize the joint positions stored in `mj_data.qpos` in place.

    The following randomization rules apply:
    - Limited hinges and sliders are sampled uniformly within their range.
    - Unlimited hinges are sampled uniformly in [-pi, pi].
    - Unlimited ball joints get a normalized 4-D Gaussian sample as their
        quaternion.
    - Limited ball joints get a random-axis rotation whose angle is drawn
        from [0, range_max) (see `random_limited_quaternion`).
    - Free joints and unlimited sliders are left untouched.

    Each joint's values are written at its own address in `qpos`
    (`mj_model.jnt_qposadr`), because free and ball joints occupy several
    `qpos` entries and the joint id is therefore not a valid `qpos` index.

    Args:
        mj_model: `mujoco.MjModel` describing the joints.
        mj_data: `mujoco.MjData` whose `qpos` is modified in place.
        random: Optional instance of `np.random.RandomState`. Defaults to the
            global NumPy random state.
    """
    random = random or np.random

    # Plain ints so the comparisons below are int-to-int.
    hinge = int(mujoco.mjtJoint.mjJNT_HINGE)
    slide = int(mujoco.mjtJoint.mjJNT_SLIDE)
    ball = int(mujoco.mjtJoint.mjJNT_BALL)

    qpos = mj_data.qpos

    for joint_id in range(mj_model.njnt):
        joint_type = int(mj_model.jnt_type[joint_id])
        is_limited = bool(mj_model.jnt_limited[joint_id])
        range_min, range_max = mj_model.jnt_range[joint_id]
        # First qpos entry owned by this joint.
        adr = int(mj_model.jnt_qposadr[joint_id])

        if is_limited:
            if joint_type in (hinge, slide):
                qpos[adr] = random.uniform(range_min, range_max)
            elif joint_type == ball:
                # A ball joint stores a quaternion in four consecutive entries.
                qpos[adr:adr + 4] = random_limited_quaternion(random, range_max)
        elif joint_type == hinge:
            qpos[adr] = random.uniform(-np.pi, np.pi)
        elif joint_type == ball:
            quat = random.randn(4)
            quat /= np.linalg.norm(quat)
            qpos[adr:adr + 4] = quat

In [None]:
mujoco.mj_resetData(mj_model, mj_data)
max_con = 12
penetrating = True


In [None]:
# Replay one mocap clip kinematically (no physics stepping), record the
# end-effector body positions and render every frame to a video.
end_eff_names = [
    "claw_T1_left",
    "claw_T1_right",
    "claw_T2_left",
    "claw_T2_right",
    "claw_T3_left",
    "claw_T3_right",
]
# Body ids of the six claws, used to index mj_data.xpos.
endeff_idxs = np.array(
    [
        mujoco.mj_name2id(mj_model, mujoco.mju_str2Type("body"), body)
        for body in end_eff_names
    ]
)
# mujoco.mj_resetData(mj_model, mj_data)
clip_idx = 100
qposes_rollout = jp.array(mocap_qpos_reshaped[clip_idx].copy())
# mj_data.qpos = temp_mocap[1].copy()
xpos_all = []
qpos_all = []
frames = []

with mujoco.Renderer(mj_model, height=480, width=480) as renderer:
    for qpos1 in qposes_rollout:
        temp_mocap = qpos1.copy()
        # temp_mocap = temp_mocap.at[3:7].set(jp.array([1,0,0,0]))
        # Pin qpos[2] to a fixed value for every frame.
        # NOTE(review): presumably the root height of a free joint -- confirm.
        temp_mocap = temp_mocap.at[2].set(0.0317)
        
        mj_data.qpos = temp_mocap.copy()
        # penetrating = True
        # while penetrating: 
        #     # randomize_limited_and_rotational_joints(mj_model,mj_data)
        #     mj_data.qpos[2] += 0.0001
        #     mujoco.mj_forward(mj_model, mj_data)
        #     penetrating = mj_data.ncon > 0
        qpos_all.append(mj_data.qpos.copy())
        # mj_data.qpos = qpos1
        # Forward pass so xpos reflects the qpos that was just written.
        mujoco.mj_forward(mj_model, mj_data)
        xpos_all.append(mj_data.xpos[endeff_idxs].copy())
        renderer.update_scene(mj_data, camera=1, scene_option=scene_option)
        pixels = renderer.render()
        frames.append(pixels)
# Stack the per-frame records along a new leading (frame) axis.
xpos_all = np.stack(xpos_all)
qpos_all = np.stack(qpos_all)
# media.show_image(pixels)
media.show_video(frames, fps=50)

In [None]:
np.mean(qpos_all[:,2])

In [None]:
temp_mocap = mocap_qpos_reshaped[clip_idx].copy()
temp_mocap[:,2] = -0.1
temp_mocap[:,3:7] = np.array([1,0,0,0])
mj_data.qpos = temp_mocap[1].copy()
mujoco.mj_forward(mj_model, mj_data)
mj_data.ncon

In [None]:
end_eff_names = [
    "claw_T1_left",
    "claw_T1_right",
    "claw_T2_left",
    "claw_T2_right",
    "claw_T3_left",
    "claw_T3_right",
]
endeff_idxs = np.array(
    [
        mujoco.mj_name2id(mj_model, mujoco.mju_str2Type("body"), body)
        for body in end_eff_names
    ]
)
mj_data.xpos[endeff_idxs]

In [None]:
plt.plot(xpos_all[:,:,-1])

In [None]:
# mujoco.mj_resetData(mj_model, mj_data)
# temp_mocap = mocap_qpos_reshaped[clip_idx].copy()
# mj_data.qpos = temp_mocap[1].copy()
with mujoco.Renderer(mj_model, height=480, width=480) as renderer:
# xpos_all = []
# for qpos1 in qposes_rollout:
    # mj_data.qpos = qpos1
    mujoco.mj_forward(mj_model, mj_data)
    # xpos_all.append(mj_data.xpos[endeff_idxs].copy())
    renderer.update_scene(mj_data, camera=1, scene_option=scene_option)
    pixels = renderer.render()
    # frames.append(pixels)
    media.show_image(pixels)
# xpos_all = np.stack(xpos_all)
mj_data.qpos[:3]

In [None]:
# Rebuild the MuJoCo model from the configured MJCF file and apply the solver
# settings from the config.
spec = mujoco.MjSpec()
spec = spec.from_file(cfg.dataset.env_args.mjcf_path)
mj_model = spec.compile()
# The dict is only a lookup table; the "cg" solver is always selected.
mj_model.opt.solver = {
    "cg": mujoco.mjtSolver.mjSOL_CG,
    "newton": mujoco.mjtSolver.mjSOL_NEWTON,
}["cg"]
mj_model.opt.iterations = cfg.dataset.env_args.iterations
mj_model.opt.ls_iterations = cfg.dataset.env_args.ls_iterations
mj_model.opt.timestep = cfg.dataset.env_args.physics_timestep

mj_data = mujoco.MjData(mj_model)

# Visualisation options: show site groups 0-4 and draw contact points/forces.
scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True

# NOTE(review): the top of the file sets MUJOCO_GL to 'egl' before importing
# mujoco; confirm that switching it here, after the import, has any effect.
os.environ["MUJOCO_GL"] = "osmesa"
mujoco.mj_kinematics(mj_model, mj_data)
# renderer = mujoco.Renderer(mj_model, height=512, width=512)
# NOTE(review): xpos_all stays an empty list in this cell because the lines
# that fill it are commented out below.
xpos_all = []
# for clip_idx in range(env._n_clips):
clip_idx=27

# qposes_rollout = np.concatenate([all_clips_interp.position[clip_idx],all_clips_interp.quaternion[clip_idx],all_clips_interp.joints[clip_idx]],axis=-1)
temp_mocap = mocap_qpos_reshaped[clip_idx].copy()
# Overwrite qpos[2] with a constant and qpos[3:7] with (1, 0, 0, 0) for all frames.
temp_mocap[:,2] = .02
temp_mocap[:,3] = 1
temp_mocap[:,4:7] = 0
qposes_rollout = jp.array(temp_mocap)
frames = []
xpos_geoms =[]
# Kinematic replay: set qpos, run the forward pass, record all body positions.
# with mujoco.Renderer(mj_model, height=480, width=480) as renderer:
for qpos1 in qposes_rollout:
    mj_data.qpos = qpos1
    mujoco.mj_forward(mj_model, mj_data)
    xpos_geoms.append(mj_data.xpos.copy())
        # renderer.update_scene(mj_data, camera=1, scene_option=scene_option)
        # pixels = renderer.render()
        # frames.append(pixels)
#     xpos_all.append(jp.stack(xpos_geoms))
# xpos_all = jp.stack(xpos_all)
# media.show_video(frames, fps=50)

In [None]:
scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = True

qpos1 = env.sys.qpos0
with mujoco.Renderer(mj_model, height=480, width=480) as renderer:
    mj_data.qpos = qpos1
    mujoco.mj_forward(mj_model, mj_data)
    renderer.update_scene(mj_data, camera='side', scene_option=scene_option)
    pixels = renderer.render()
media.show_image(pixels)

In [None]:
ref_clip_interp['body_positions'] = xpos_all

In [None]:
env._endeff_idxs

In [None]:
# plt.plot(xpos_all[0,:,27])
# plt.plot(reference_clip.body_positions[0][:,27])
plt.plot(all_clips_reference.body_positions[0][:,27])
# media.show_video(frames, fps=50)

In [None]:
for key in ref_clip.keys():
    ref_clip_interp[key] = jp.stack(ref_clip_interp[key],axis=0)
    print(key, ref_clip_interp[key].shape)


In [None]:
all_clips_reference = ReferenceClip()
all_clips_reference =all_clips_reference.replace(
    position=jp.array(ref_clip_interp['position']),
    quaternion=jp.array(ref_clip_interp['quaternion']),
    joints=jp.array(ref_clip_interp['joints']),
    body_positions=jp.array(ref_clip_interp['body_positions']),
    velocity=jp.array(ref_clip_interp['velocity']),
    joints_velocity=jp.array(ref_clip_interp['joints_velocity']),
    angular_velocity=jp.array(ref_clip_interp['angular_velocity']),
    body_quaternions=jp.array(ref_clip_interp['body_quaternions']),
)

In [None]:

reference_path = Path(cfg.paths.data_dir)/ "clips/all_clips_batch_interp2.p"
with open(reference_path, "wb") as file:
    # Use pickle.dump() to save the data to the file
    pickle.dump(all_clips_reference, file)

In [None]:
reference_path = Path('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/clips/all_clips_list_interp.p')
with open(reference_path, "wb") as file:
    # Use pickle.dump() to save the data to the file
    pickle.dump(all_clips, file)

In [None]:
all_clips_reference.position.shape

In [None]:
all_ref_clip = {}
all_ref_clip['position'] = reference_clip.position
all_ref_clip['quaternion'] = reference_clip.quaternion
all_ref_clip['joints'] = reference_clip.joints
all_ref_clip['body_positions'] =reference_clip.body_positions
all_ref_clip['velocity'] = reference_clip.velocity
all_ref_clip['joints_velocity'] = reference_clip.joints_velocity
all_ref_clip['angular_velocity'] = reference_clip.angular_velocity
all_ref_clip['body_quaternions'] = reference_clip.body_quaternions

In [None]:
reference_path = Path(cfg.paths.data_dir)/ f"clips/all_clips_interp.p"
# reference_path = cfg.paths.data_dir/ f"clips/0.p"
reference_path.parent.mkdir(parents=True, exist_ok=True)

with open(reference_path, "rb") as file:
    # Use pickle.load() to load the data from the file
    reference_clip = pickle.load(file)

In [None]:
np.cumsum([reference_clip.position[n].shape[0] for n in range(len(reference_clip.position))])

In [None]:
all_ref_clip = {}
all_ref_clip['position'] = jp.concatenate([clip for clip in reference_clip.position],axis=0)
all_ref_clip['quaternion'] = jp.concatenate([clip for clip in reference_clip.quaternion],axis=0)
all_ref_clip['joints'] = jp.concatenate([clip for clip in reference_clip.joints],axis=0)
all_ref_clip['body_positions'] = jp.concatenate([clip for clip in reference_clip.body_positions],axis=0)
all_ref_clip['velocity'] = jp.concatenate([clip for clip in reference_clip.velocity],axis=0)
all_ref_clip['joints_velocity'] = jp.concatenate([clip for clip in reference_clip.joints_velocity],axis=0)
all_ref_clip['angular_velocity'] = jp.concatenate([clip for clip in reference_clip.angular_velocity],axis=0)
all_ref_clip['body_quaternions'] = jp.concatenate([clip for clip in reference_clip.body_quaternions],axis=0)

In [None]:
# Build a list with one ReferenceClip per entry along the leading axis of
# all_ref_clip['position'], copying entry n of every field into it.
# NOTE(review): what "entry n" means depends on which cell last populated
# all_ref_clip (per-clip arrays vs. frames concatenated along axis 0) --
# confirm the intended cell order before running.
all_clips = []
for n in range(len(all_ref_clip['position'])):
    temp_clip = ReferenceClip()
    temp_clip = temp_clip.replace(
        position = all_ref_clip['position'][n],
        quaternion = all_ref_clip['quaternion'][n],
        joints = all_ref_clip['joints'][n],
        body_positions = all_ref_clip['body_positions'][n],
        velocity = all_ref_clip['velocity'][n],
        joints_velocity = all_ref_clip['joints_velocity'][n],
        angular_velocity = all_ref_clip['angular_velocity'][n],
        body_quaternions = all_ref_clip['body_quaternions'][n]
    )
    all_clips.append(temp_clip)

In [None]:
reference_clip.position[5]

In [None]:
reference_path = Path(cfg.paths.data_dir)/ f"clips/all_clips_interp.p"
# reference_path = cfg.paths.data_dir/ f"clips/0.p"
reference_path.parent.mkdir(parents=True, exist_ok=True)

with open(reference_path, "rb") as file:
    # Use pickle.load() to load the data from the file
    reference_clip = pickle.load(file)

In [None]:
clip_lengths = jp.array([reference_clip.position[n].shape[0] for n in range(len(reference_clip.position))])
clip_start_inds = np.concatenate(([0],np.cumsum(clip_lengths)[:-1]),axis=0)

In [None]:
clip_lengths

In [None]:
reference_clip = reference_clip.replace(
    position = all_ref_clip['position'],
    quaternion = all_ref_clip['quaternion'],
    joints = all_ref_clip['joints'],
    body_positions = all_ref_clip['body_positions'],
    velocity = all_ref_clip['velocity'],
    joints_velocity = all_ref_clip['joints_velocity'],
    angular_velocity = all_ref_clip['angular_velocity'],
    body_quaternions = all_ref_clip['body_quaternions'],
)

## Treadmill data

In [None]:
import pandas as pd
base_path = Path('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/')

# Load the linear-treadmill keypoint table.
tredmill_data = pd.read_csv(base_path/'wt_berlin_linear_treadmill_dataset.csv')
kp_names = ['head', 'thorax', 'abdomen', 'r1', 'r2', 'r3', 'l1', 'l2', 'l3']
coords = ['_x', '_y', '_z']
# Column order is keypoint-major: head_x, head_y, head_z, thorax_x, ...
df_names = [kp+coord for kp in kp_names for coord in coords]
kp_data_all = tredmill_data[df_names].values

sorted_kp_names = kp_names
# Scale by 0.3 and split the rows into a leading axis of length 1800; the
# last axis keeps the 27 keypoint-major coordinates.
# NOTE(review): other cells reshape the matching qpos array as (581, 1800, -1),
# i.e. with the two leading axes in the opposite order -- confirm which axis is
# the trial and which is the frame.
kp_data = .3*kp_data_all.copy().reshape(1800,-1,27)
belt_speed = tredmill_data['belt speed (mm/s)'].values.reshape(1800,-1)

In [None]:
stac_path = Path('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/transform_treadmill.p')
# stac_path = Path('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/transform_mocap_fly_freejnt.p')
# Load mocap data from a file.
with open(stac_path, "rb") as file:
    d = pickle.load(file)
    mocap_qpos = jp.array(d["qpos"])

In [None]:
mocap_qpos.reshape(581,1800,-1).shape

In [None]:
1045800/1800

In [None]:
fly_skel = ((0,1),(1,2),(1,3),(1,4),(1,6),(1,7),(1,8))


In [None]:
# The 27 columns are keypoint-major (head_x, head_y, head_z, thorax_x, ...), so
# they must first be split as (keypoint, coord); reshaping straight to
# (3, -1) would mix coordinates of different keypoints. The coordinate axis is
# then moved in front of the keypoint axis, giving the (…, coord, keypoint)
# layout that the plotting cells index as kp_data[..., 0, k] / kp_data[..., 1, k].
# Run this once on the (1800, 581, 27) array; re-running it on its own output
# would permute the axes again.
kp_data = kp_data.reshape(1800,581,-1,3).swapaxes(-1,-2)

In [None]:
joint_idx = np.array([root.joint(i).id for i in range(root.njnt)])[1:]
end_eff = [
'claw_T1_left',
'claw_T1_right',
'claw_T2_left',
'claw_T2_right',
'claw_T3_left',
'claw_T3_right',
]
endeff_idxs = jp.array(
    [
        mujoco.mj_name2id(root, mujoco.mju_str2Type("body"), body)
        for body in end_eff
    ]
)

In [None]:
model_path = ("/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_fastviz.xml")

# Keep the spec returned by from_file, as the other model-loading cells do;
# discarding it would compile the default, empty spec instead of the fly model.
spec = mujoco.MjSpec()
spec = spec.from_file(model_path)
thorax0 = spec.find_body("thorax")
mj_model = spec.compile()
# The dict is only a lookup table; the "cg" solver is always selected.
mj_model.opt.solver = {
    "cg": mujoco.mjtSolver.mjSOL_CG,
    "newton": mujoco.mjtSolver.mjSOL_NEWTON,
}["cg"]
mj_model.opt.iterations = cfg.dataset.env_args.iterations
mj_model.opt.ls_iterations = cfg.dataset.env_args.ls_iterations
# Use the same timestep as the environment's model.
mj_model.opt.timestep = env.sys.mj_model.opt.timestep
mj_data = mujoco.MjData(mj_model)

scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]

mujoco.mj_kinematics(mj_model, mj_data)
# First 500 frames of the first (581, 1800, -1) slice of the fitted qpos,
# with qpos[2] pinned to a constant.
qposes_rollout = mocap_qpos.reshape(581,1800,-1)[0,:500]
qposes_rollout = qposes_rollout.at[:,2].set(.05)
# qposes_rollout = qposes_rollout.at[:,3:7].set(0)
frames = []
# Kinematic replay: set qpos, run the forward pass, render from 'track1'.
with mujoco.Renderer(mj_model, height=512, width=512) as renderer:
    for qpos1 in qposes_rollout:
        mj_data.qpos = qpos1
        mujoco.mj_forward(mj_model, mj_data)
        renderer.update_scene(mj_data, camera='track1', scene_option=scene_option)
        pixels = renderer.render()
        frames.append(pixels)

In [None]:
media.show_video(frames, fps=10)

In [None]:
t0 = kp_data[0,0,:,:]

for n in range(7):
    plt.plot(t0[0,fly_skel[n]],t0[1,fly_skel[n]])
    plt.scatter(t0[0,fly_skel[n]],t0[1,fly_skel[n]])

In [None]:
# Scatter the (x, y) position of both endpoints of every skeleton edge for the
# first frame. fly_skel is a tuple of (start, end) pairs, so each pair is
# fetched with fly_skel[n]; tuple-style indexing such as fly_skel[n, 0] raises
# TypeError on a plain tuple.
for n in range(len(fly_skel)):
    edge = list(fly_skel[n])
    plt.plot(kp_data[0,0,0,edge],kp_data[0,0,1,edge], '.')

# Amputation data

In [None]:
import pandas as pd
base_path = Path('/data/users/eabe/biomech_model/Flybody/datasets/Tuthill_data/')
data_path = base_path / 'Amputation/09302024_fly2_0 R1C1_pose-3d.csv'
df = pd.read_csv(data_path)

In [None]:
df.shape

In [None]:

# Build the list of 3-D pose columns to extract: five joints (A-E) on each of
# the listed legs, followed by the body landmarks.
# NOTE(review): 'L1' is absent from legs_data -- presumably the amputated leg; confirm.
legs_data = ['R1', 'L2','R2', 'L3','R3']
joints_data = ['A','B','C','D','E']
coords_data = ['_x','_y','_z']
# Ordering is leg-major, then joint, then coordinate (R1A_x, R1A_y, R1A_z, R1B_x, ...).
joint_pos_columns = [leg + joint + coord 
                     for leg in legs_data
                     for joint in joints_data 
                     for coord in coords_data]
all_cols = (joint_pos_columns
+['abdomen-tip_x',
'abdomen-tip_y',
'abdomen-tip_z',
'stripe-3_x',
'stripe-3_y',
'stripe-3_z',
'stripe-1_x',
'stripe-1_y',
'stripe-1_z',
'thorax-abdomen_x',
'thorax-abdomen_y',
'thorax-abdomen_z',
'head-thorax_x',
'head-thorax_y',
'head-thorax_z',])

# Reshape to (1500, keypoint, xyz); the leading size is hard-coded.
# NOTE(review): assumes the CSV has exactly 1500 rows -- confirm.
kp_data = df[all_cols].values.reshape(1500,-1, 3)

In [None]:
# Load the fly model and look up the tracking site that serves as the data origin.
model_path = ("/home/eabe/Research/MyRepos/Brax-Rodent-Track/assets/fruitfly/fruitfly_fastviz.xml")
spec = mujoco.MjSpec()
spec = spec.from_file(model_path)
mj_model = spec.compile()
# Plain integer site id (a trailing comma here would silently create a 1-tuple).
init_site_idx = mujoco.mj_name2id(mj_model, mujoco.mju_str2Type("site"), "tracking[coxa_T1_left]")
data = mujoco.MjData(mj_model)
# Forward pass so that world-frame positions in `data` are populated.
mujoco.mj_forward(mj_model, data)
# mj_model.find_body("thorax").id
# frame0 += site_pos[0, :]  # Body-coxa T1 left joint is the data origin.


In [None]:
def transform_frame(frame):
    """Map one frame of keypoints from the data frame to the model frame.

    The x and y columns are swapped and the new y column is negated, i.e.
    (x, y, z) -> (y, -x, z). The input array is not modified.

    Args:
        frame: Array of shape (keypoint, 3).

    Returns:
        A new array of the same shape with the columns remapped.
    """
    # Fancy indexing yields a fresh array, so the negation below cannot
    # touch the caller's data.
    remapped = frame[:, [1, 0, 2]]
    remapped[:, 1] *= -1
    # Unit conversion (mm to cm) is intentionally disabled.
    # remapped *= 0.1
    return remapped
# Express every frame in the model frame and shift it by the world position of
# the origin site. init_site_idx is a *site* id, so it must index site_xpos;
# data.xpos holds body positions and is indexed by body id.
kp_transform = np.array([transform_frame(frame)+data.site_xpos[init_site_idx] for frame in kp_data])

In [None]:
data_dict = {'kp_data':kp_data, 'kp_transform':kp_transform,
             'kp_names':all_cols}

In [None]:
all_cols

In [None]:
ioh5.save(base_path / 'Amputation/kp_data_amp.h5', data_dict)

In [None]:
kp_data.shape

In [None]:
len(joint_pos_columns)/3

In [None]:
'abdomen-tip_x',
'abdomen-tip_y',
'abdomen-tip_z',
'stripe-3_x',
'stripe-3_y',
'stripe-3_z',
'stripe-1_x',
'stripe-1_y',
'stripe-1_z',
'thorax-abdomen_x',
'thorax-abdomen_y',
'thorax-abdomen_z',
'head-thorax_x',
'head-thorax_y',
'head-thorax_z',
'l-eye-t_x',
'l-eye-t_y',
'l-eye-t_z',
'l-eye-b_x',
'l-eye-b_y',
'l-eye-b_z',
'r-eye-t_x',
'r-eye-t_y',
'r-eye-t_z',
'r-eye-b_x',
'r-eye-b_y',
'r-eye-b_z',

# Visualize KP data


In [None]:
kp_data = bout_dict['walking_bout{:04}'.format(n)]['orig_xpos']
all_bout_inits = jp.stack([bout_dict['walking_bout{:04}'.format(n)]['orig_xpos'][0] for n in range(0,1000)])

In [None]:
import jax.numpy as jp

def procrustes_jax(X, Y):
    """
    Perform Procrustes analysis to align matrix Y to matrix X.

    Both inputs are centered on their column means and scaled to unit
    Frobenius norm; the orthogonal matrix R that minimizes
    ``||Y_unit @ R - X_unit||`` is then obtained from an SVD. R is orthogonal
    but not forced to be a proper rotation (it may include a reflection).

    Parameters:
    X : (N, M) array_like
        The reference matrix.
    Y : (N, M) array_like
        The matrix to align to X.

    Returns:
    d : float
        Residual sum of squared errors between the centered, unit-norm X and
        the centered, unit-norm, rotated Y (scale independent).
    Z : (N, M) ndarray
        Y centered, rotated by R and rescaled to the Frobenius norm of the
        centered X. Z is centered on the origin; the mean of X is not added.
    R : (M, M) ndarray
        Orthogonal matrix applied to Y by right-multiplication (``Y @ R``).
    scale : float
        Ratio of the Frobenius norms of the centered X and centered Y, so
        that ``Z == (Y - mean(Y)) @ R * scale``.
    """
    # Center both matrices
    X_centered = X - jp.mean(X, axis=0)
    Y_centered = Y - jp.mean(Y, axis=0)

    # Normalize the Frobenius norm to 1
    X_norm = jp.linalg.norm(X_centered)
    Y_norm = jp.linalg.norm(Y_centered)
    X_unit = X_centered / X_norm
    Y_unit = Y_centered / Y_norm

    # Orthogonal Procrustes: with Y_unit.T @ X_unit = U S Vt, the minimizer of
    # ||Y_unit @ R - X_unit|| is R = U @ Vt.
    U, _, Vt = jp.linalg.svd(jp.dot(Y_unit.T, X_unit))
    R = jp.dot(U, Vt)

    # R acts by right-multiplication, matching how callers apply it.
    Y_aligned = jp.dot(Y_unit, R)

    # Residual measured with both matrices on the same (unit-norm) scale.
    d = jp.sum(jp.square(X_unit - Y_aligned))

    # Bring the aligned matrix back to the scale of X.
    Z = Y_aligned * X_norm
    scale = X_norm / Y_norm
    return d, Z, R, scale


In [None]:
clip_idx = 100
test_bout = bout_dict['walking_bout{:04}'.format(clip_idx)]['orig_xpos']
# test_bout_A = (test_bout - test_bout[:,:1,:])
A = ref_bout.copy() #test_bout[0].copy()
clip_idx = 100
test_bout = bout_dict['walking_bout{:04}'.format(clip_idx)]['orig_xpos']
# test_bout_B = (test_bout - test_bout[:,:1,:])
# B = test_bout[0].copy() 

# test_bout_A.shape, test_bout_B.shape
endeff_kp = test_bout[:,4::5].copy()

In [None]:
# Collect, for every end-effector keypoint, the frames in which its z value is
# below that keypoint's own 10th percentile.
all_bot_10 = []
# One z threshold per keypoint (percentile taken over the frame axis).
bot_10 = np.percentile(endeff_kp[:,:,2], 10, axis=0)
for n in range(endeff_kp.shape[1]):
    # Rows of keypoint n whose z is under its threshold; each entry is (k_n, 3).
    all_bot_10.append(endeff_kp[np.where(endeff_kp[:,n,2] < bot_10[n])[0],n,:])


# Pool the selected points of all keypoints into a single (K, 3) array.
bot_t = np.concatenate(all_bot_10)

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def fit_plane(points):
    """
    Fit a plane to a set of 3D points using Singular Value Decomposition (SVD).

    The plane passes through the centroid of the points; its normal is the
    right singular vector of the centered points with the smallest singular
    value. The sign of the normal is not fixed.

    Args:
        points (jax.numpy.ndarray): Array of shape (N, 3), where N is the number of points.

    Returns:
        tuple: A tuple containing:
            - normal (jax.numpy.ndarray): The unit normal vector of the plane (shape: (3,)).
            - d (float): The intercept term of the plane equation ax + by + cz + d = 0.
    """
    # Compute the centroid of the points
    centroid = jp.mean(points, axis=0)

    # Center the points around the centroid
    centered_points = points - centroid

    # Singular values come back in descending order, so the last row of vh
    # belongs to the smallest one. jp.linalg.svd is used (as elsewhere in this
    # file) so the function does not depend on a separately imported `svd`.
    _, _, vh = jp.linalg.svd(centered_points, full_matrices=False)
    normal = vh[-1]

    # The centroid lies on the plane, which fixes the intercept.
    d = -jp.dot(normal, centroid)

    return normal, d

def plot_plane_and_points(points, normal, d, fig=None, ax=None, color='r', label=''):
    """
    Plot the plane and the 3D points.

    The plane ax + by + cz + d = 0 is drawn over the x/y bounding box of the
    points by solving for z, so the normal must have a non-zero z component.

    Args:
        points (jax.numpy.ndarray): Array of shape (N, 3), the 3D points.
        normal (jax.numpy.ndarray): Normal vector of the plane.
        d (float): Intercept term of the plane equation.
        fig: Optional matplotlib figure. Only used to create the axes when
            `ax` is not given; a new figure is created if both are None.
        ax: Optional 3D axes to draw into.
        color: Matplotlib color used for both the points and the plane.
        label (str): Prefix for the legend entries.
    """
    print('points',points.shape)
    # Create a mesh grid for the plane
    x = jp.linspace(points[:, 0].min(), points[:, 0].max(), 10)
    y = jp.linspace(points[:, 1].min(), points[:, 1].max(), 10)
    x, y = jp.meshgrid(x, y)

    # Compute the corresponding z values for the plane equation ax + by + cz + d = 0
    z = (-normal[0] * x - normal[1] * y - d) / normal[2]

    # Create whatever is missing: axes are needed even when only `fig` was
    # supplied, otherwise `ax` would still be None below.
    if ax is None:
        if fig is None:
            fig = plt.figure(figsize=(10, 7))
        ax = fig.add_subplot(111, projection='3d')

    # Plot the points
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=color, label=label+' Points')

    # Plot the plane
    ax.plot_surface(x, y, z, alpha=0.5, color=color, label=label + ' Plane')

    # Set labels and title
    ax.set_xlabel("X-axis")
    ax.set_ylabel("Y-axis")
    ax.set_zlabel("Z-axis")
    ax.set_title("Fitted Plane and 3D Points")
    ax.legend()
    

def best_fit_affine(source_points, target_points):
    """
    Compute the best-fit rigid transformation (rotation and translation)
    that maps source_points to target_points.

    Despite the name, no scaling or shear is estimated: the result is a proper
    rotation R (det(R) = +1) and a translation t such that
    ``source_points @ R.T + t`` approximates ``target_points`` in the
    least-squares sense.

    Args:
        source_points (jax.numpy.ndarray): Points to be moved (shape: Nx3).
        target_points (jax.numpy.ndarray): Corresponding target points (shape: Nx3).

    Returns:
        R (jax.numpy.ndarray): Rotation matrix (3x3).
        t (jax.numpy.ndarray): Translation vector (3,).

    Raises:
        ValueError: If the inputs do not both have shape (N, 3).
    """
    if (
        source_points.ndim != 2
        or source_points.shape != target_points.shape
        or source_points.shape[1] != 3
    ):
        raise ValueError("Both input point sets must have shape (N, 3).")

    # Compute centroids of both point sets
    centroid_src = jp.mean(source_points, axis=0)
    centroid_tgt = jp.mean(target_points, axis=0)

    # Center the points around their respective centroids
    centered_src = source_points - centroid_src
    centered_tgt = target_points - centroid_tgt

    # Compute the cross-covariance matrix
    H = jp.dot(centered_src.T, centered_tgt)

    # Compute the SVD of the cross-covariance matrix
    U, _, Vt = jp.linalg.svd(H)

    # Compute the rotation matrix
    R = jp.dot(Vt.T, U.T)

    # Ensure a proper rotation (det(R) = 1, avoiding reflection). JAX arrays
    # are immutable, so the last row is flipped with a functional update
    # rather than in-place assignment.
    if jp.linalg.det(R) < 0:
        Vt = Vt.at[-1, :].multiply(-1)
        R = jp.dot(Vt.T, U.T)

    # Compute the translation vector
    t = centroid_tgt - jp.dot(R, centroid_src)

    return R, t

In [None]:
plt.close('all')

In [None]:
source_points = bot_t
target_points = bot_t.copy()
target_points[:,2] = -0.13
source_normal, source_d = fit_plane(source_points)
target_normal, target_d = fit_plane(target_points)

R, t = best_fit_affine(source_points, target_points)
tranformed_points = jp.dot(source_points, R.T) + t

In [None]:
# Compare three planes in one 3-D plot: the plane fitted to the lowest
# end-effector points (source), the same points flattened onto z = -0.13
# (target), and the source points after the fitted rigid transform.
source_points = bot_t
target_points = bot_t.copy()
target_points[:,2] = -0.13
source_normal, source_d = fit_plane(source_points)
target_normal, target_d = fit_plane(target_points)
tranformed_normal, tranformed_d = fit_plane(tranformed_points)


fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
plot_plane_and_points(source_points.reshape(-1,3), source_normal, source_d, fig, ax, color='b', label='Source')
plot_plane_and_points(target_points.reshape(-1,3), target_normal, target_d, fig, ax, color='k', label='ground')
# Give the transformed points their own legend label instead of repeating 'ground'.
plot_plane_and_points(tranformed_points.reshape(-1,3), tranformed_normal, tranformed_d, fig, ax, color='r', label='transformed')



In [None]:
# def visualize_affine_transformation_with_planes(
#     source_points, target_points, transformed_points, source_plane, target_plane, transformed_plane
# ):
#     """
#     Visualize the source points, target points, transformed points, and the planes in 3D.
    
#     Args:
#         source_points (jax.numpy.ndarray): Original source points (Nx3).
#         target_points (jax.numpy.ndarray): Target points (Nx3).
#         transformed_points (jax.numpy.ndarray): Transformed source points (Nx3).
#         source_plane (tuple): Coefficients of the source plane (normal vector and intercept).
#         target_plane (tuple): Coefficients of the target plane (normal vector and intercept).
#     """
#     fig = plt.figure(figsize=(10, 7))
#     ax = fig.add_subplot(111, projection='3d')
    
#     # Plot source points
#     ax.scatter(
#         source_points[:, 0], source_points[:, 1], source_points[:, 2],
#         c='blue', label='Source Points', s=50
#     )
    
#     # Plot target points
#     ax.scatter(
#         target_points[:, 0], target_points[:, 1], target_points[:, 2],
#         c='green', label='Target Points', s=50
#     )
    
#     # Plot transformed source points
#     ax.scatter(
#         transformed_points[:, 0], transformed_points[:, 1], transformed_points[:, 2],
#         c='red', label='Transformed Source Points', marker='x', s=50
#     )
    
#     # Create the source plane
#     source_normal, source_d = source_plane
#     x = jnp.linspace(source_points[:, 0].min(), source_points[:, 0].max(), 10)
#     y = jnp.linspace(source_points[:, 1].min(), source_points[:, 1].max(), 10)
#     x, y = jnp.meshgrid(x, y)
#     z = (-source_normal[0] * x - source_normal[1] * y - source_d) / source_normal[2]
#     ax.plot_surface(x, y, z, alpha=0.3, color='blue', label='Source Plane')
    
#     # Create the target plane
#     target_normal, target_d = target_plane
#     x = jnp.linspace(target_points[:, 0].min(), target_points[:, 0].max(), 10)
#     y = jnp.linspace(target_points[:, 1].min(), target_points[:, 1].max(), 10)
#     x, y = jnp.meshgrid(x, y)
#     z = (-target_normal[0] * x - target_normal[1] * y - target_d) / target_normal[2]
#     ax.plot_surface(x, y, z, alpha=0.3, color='green', label='Target Plane')
    
#     # Create the transformed plane
#     transformed_normal, transformed_d = transformed_plane
#     x = jnp.linspace(transformed_points[:, 0].min(), transformed_points[:, 0].max(), 10)
#     y = jnp.linspace(transformed_points[:, 1].min(), transformed_points[:, 1].max(), 10)
#     x, y = jnp.meshgrid(x, y)
#     z = (-transformed_normal[0] * x - transformed_normal[1] * y - transformed_d) / transformed_normal[2]
#     ax.plot_surface(x, y, z, alpha=0.3, color='red', label='Transformed Plane')
    
#     # Set labels and legend
#     ax.set_xlabel("X-axis")
#     ax.set_ylabel("Y-axis")
#     ax.set_zlabel("Z-axis")
#     ax.set_title("Affine Transformation Visualization with Planes")
#     ax.legend()
    
#     plt.show()


# source_points = jp.array(jp.mean(test_bout[:,0::5],axis=0)).copy()
source_points = jp.array(test_bout[:,4::5]).copy()
# target_points = jp.array(ref_bout[0::5].reshape(-1,3)).copy()
target_points = xpos_all.copy()
# Fit planes to the source and target points
source_normal, source_d = fit_plane(source_points.reshape(-1,3))
target_normal, target_d = fit_plane(target_points.reshape(-1,3))

# Transform source points
# d, Z_all, R, scale = procrustes_jax(ref_bout[0::5].copy(),source_points.copy())
# Z_all = jp.dot(test_bout, R) * scale
# Z_all = np.array(Z_all)
# transformed_points = Z_all[0,::5].copy()
# transformed_normal, transformed_d = fit_plane(transformed_points.reshape(-1,3))

# Plot the points and the plane
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
plot_plane_and_points(source_points.reshape(-1,3), source_normal, source_d, fig, ax, color='blue',label='source')
plot_plane_and_points(target_points.reshape(-1,3), target_normal, target_d, fig, ax, color='green',label='target')
# plot_plane_and_points(transformed_points, transformed_normal, transformed_d, fig, ax, color='red',label='transformed')


In [None]:
xpos_all.shape

In [None]:
source_points.shape

In [None]:

source_points = jp.array(jp.array(test_bout[:,4::5])).copy()
target_points = jp.array(ref_bout[4::5].reshape(-1,3)).copy()
# Fit planes to the source and target points
source_normal, source_d = fit_plane(source_points.reshape(-1,3))
target_normal, target_d = fit_plane(target_points)

# Transform source points
# d, Z_all, R, scale = procrustes_jax(ref_bout[0::5],source_points[0])
# Z_all = jp.dot(test_bout, R) * scale
transformed_points = Z_all[:,4::5].copy()
transformed_normal, transformed_d = fit_plane(transformed_points.reshape(-1,3))

# Plot the points and the plane
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
plot_plane_and_points(source_points.reshape(-1,3), source_normal, source_d, fig, ax, color='blue')
plot_plane_and_points(target_points, target_normal, target_d, fig, ax, color='green')
plot_plane_and_points(transformed_points.reshape(-1,3), transformed_normal, transformed_d, fig, ax, color='red')


In [None]:
import matplotlib as mpl
def map_discrete_cbar(cmap,N):
    """Return a colormap resampled to N+1 colors and a matching BoundaryNorm.

    Args:
        cmap: Colormap name passed to `plt.get_cmap`.
        N: Largest integer level; levels 0..N each get their own color bin.

    Returns:
        (cmap, norm): the resampled colormap and a `BoundaryNorm` whose bin
        edges sit at -0.5, 0.5, ..., N + 0.5.
    """
    n_levels = N + 1
    discrete_cmap = plt.get_cmap(cmap, n_levels)
    # Half-integer edges so that each integer level falls in the middle of a bin.
    edges = np.arange(-.5, n_levels)
    return discrete_cmap, mpl.colors.BoundaryNorm(edges, discrete_cmap.N)
cmap,norm = map_discrete_cbar('turbo',6)
fly_skel = np.array([(n,n+1) for n in range(29)])
fly_skel = np.stack((fly_skel[:4],fly_skel[5:9],fly_skel[10:14],fly_skel[15:19],fly_skel[20:24],fly_skel[25:30],))

Z = Z_all.copy()
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(ref_bout[::5, 0], ref_bout[::5, 1],ref_bout[::5, 2], c=np.arange(ref_bout[::5].shape[0]),cmap=cmap)
# ax.scatter(Z[::5, 0], Z[::5, 1],Z[::5, 2], c=np.arange(Z[::5].shape[0]),cmap=cmap)
ax.scatter(Z[:,4::5, 0], Z[:,4::5, 1],Z[:,4::5, 2],c='r')
for n in range(fly_skel.shape[0]):
    ax.plot(ref_bout[fly_skel[n], 0], ref_bout[fly_skel[n], 1],  ref_bout[fly_skel[n], 2],c='k')
    ax.plot(Z[0,fly_skel[n], 0], Z[0,fly_skel[n], 1],Z[0,fly_skel[n], 2],c='r')
plt.show()

In [None]:
plt.close('all')

In [None]:

d, Z0, R, scale = procrustes_jax(A[::5], B[::5])
# Z1,Z2,d = procrustes(A[::5], B[::5])
Z = jp.dot(B, R) * scale
d

In [None]:
# procrustes_partial = functools.partial(procrustes_jax)
jit_procrustes = jax.jit(procrustes_jax)
vmap_procrustes_partial = jax.vmap(jit_procrustes, (None,0))
d, Z_all, R, scale = vmap_procrustes_partial(ref_bout,all_bout_inits)

In [None]:
d,Z,R,scale = procrustes_jax(A,B)
Z_all = jp.dot(test_bout, R) * scale
Z_all = np.array(Z_all)

In [None]:
import matplotlib as mpl
def map_discrete_cbar(cmap, N):
    """Return a colormap resampled to ``N + 1`` colors and a matching norm.

    Args:
        cmap: Colormap name (or colormap object) accepted by ``plt.get_cmap``.
        N: Largest integer level; levels ``0..N`` each get their own color.

    Returns:
        ``(colormap, norm)`` where ``norm`` is a ``BoundaryNorm`` whose bin
        edges are ``-0.5, 0.5, ..., N + 0.5``.
    """
    resampled = plt.get_cmap(cmap, N + 1)
    # Half-integer edges put every integer level in the middle of its bin.
    bin_edges = np.arange(-0.5, N + 1)
    boundary_norm = mpl.colors.BoundaryNorm(bin_edges, resampled.N)
    return resampled, boundary_norm

# Overlay the reference pose (A) and one raw frame of a walking bout (B) in
# two 2D projections.
frame_idx = 0
clip_idx = 1000
# test_bout = transform_bout(bout_dict['walking_bout{:04}'.format(clip_idx)]['orig_xpos'])
# Z = test_bout[0]# bout_dict['walking_bout{:04}'.format(clip_idx)]['orig_xpos'][frame_idx, :].reshape(30, 3) # (keypoint, xyz)
# Consecutive index pairs (n, n+1) for n = 0..28, regrouped into six chunks of
# four pairs each; pairs 4, 9, 14, 19 and 24 are left out.
# presumably each chunk is one 5-keypoint limb — TODO confirm
fly_skel = np.array([(n,n+1) for n in range(29)])
fly_skel = np.stack((fly_skel[:4],fly_skel[5:9],fly_skel[10:14],fly_skel[15:19],fly_skel[20:24],fly_skel[25:30],))

# Untransformed positions of the selected clip (the transform_bout call above is commented out).
test_bout = bout_dict['walking_bout{:04}'.format(clip_idx)]['orig_xpos']
# Copies, so later in-place edits of A/B cannot alter ref_bout or test_bout.
A = ref_bout.copy() #test_bout[0].copy()
B = test_bout[frame_idx].copy() 

# `norm` is returned as well but is not passed to any plotting call in this cell.
cmap,norm = map_discrete_cbar('turbo',6)
fig, axs = plt.subplots(1,2,figsize=(12, 4))

# Left panel: columns 0 and 1 (x vs y).
ax=axs[0]
# Green star marks the origin.
ax.scatter(0, 0, marker='*',c='g',zorder=10)
# Every 5th keypoint (starting at index 0) of each pose, colored by order.
ax.scatter(A[::5, 0], A[::5, 1], c=np.arange(A[::5].shape[0]),cmap=cmap)
ax.scatter(B[::5, 0], B[::5, 1], c=np.arange(B[::5].shape[0]),cmap=cmap)
for n in range(fly_skel.shape[0]):
    # Reference skeleton in black, test-frame skeleton in red.
    ax.plot(A[fly_skel[n], 0], A[fly_skel[n], 1],c='k')
    ax.plot(B[fly_skel[n], 0], B[fly_skel[n], 1],c='r')
ax.axis('equal')
# NOTE(review): B is taken from 'orig_xpos' without transform_bout, so the
# cm unit in these labels may not hold for it — confirm.
ax.set_xlabel('x (cm)')
ax.set_ylabel('y (cm)')
ax.set_title('top-down view')

# Right panel: columns 0 and 2 (x vs z).
ax = axs[1]
ax.scatter(0, 0, marker='*',c='g',zorder=10)
ax.scatter(A[::5, 0], A[::5, 2], c=np.arange(A[::5].shape[0]),cmap=cmap)
ax.scatter(B[::5, 0], B[::5, 2], c=np.arange(B[::5].shape[0]),cmap=cmap)
for n in range(fly_skel.shape[0]):
    ax.plot(A[fly_skel[n], 0], A[fly_skel[n], 2],c='k')
    ax.plot(B[fly_skel[n], 0], B[fly_skel[n], 2],c='r')
ax.axis('equal')
ax.set_xlabel('x (cm)')
ax.set_ylabel('z (cm)')
ax.set_title('side view')
plt.show()

In [None]:
import matplotlib as mpl
def map_discrete_cbar(cmap, N):
    """Return a colormap resampled to ``N + 1`` colors and a matching norm.

    Args:
        cmap: Colormap name (or colormap object) accepted by ``plt.get_cmap``.
        N: Largest integer level; levels ``0..N`` each get their own color.

    Returns:
        ``(colormap, norm)`` where ``norm`` is a ``BoundaryNorm`` whose bin
        edges are ``-0.5, 0.5, ..., N + 0.5``.
    """
    palette = plt.get_cmap(cmap, N + 1)
    # Half-integer edges put every integer level in the middle of its bin.
    level_edges = np.arange(-0.5, N + 1)
    return palette, mpl.colors.BoundaryNorm(level_edges, palette.N)

# Discrete 'turbo' colormap with 7 levels (N=6 gives N+1 colors) and its norm,
# stored as notebook globals.
cmap,norm = map_discrete_cbar('turbo',6)


In [None]:
# Apply R (transposed) to all of test_bout and add t.
# NOTE(review): t is presumably a translation here, but the plotting cells
# below rebind `t = 10` as a frame index — make sure t holds the intended value
# before (re)running this cell.
Z_all = jp.dot(test_bout, R.T) + t

In [None]:
# Compare one frame of the raw bout (A) with the same frame of the aligned
# bout (Z) in 3D, each re-centred on its own first keypoint, and overlay the
# fitted planes.
# NOTE(review): `t` is rebound to a frame index here, while the cell computing
# `Z_all = jp.dot(test_bout, R.T) + t` also reads `t` — re-running that cell
# after this one would add 10 instead of the earlier value.
t = 10
A = test_bout[t]
A = A - A[0]
Z = Z_all[t]
Z = Z - Z[0]

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# Green star marks the origin (the shared first keypoint after re-centring).
ax.scatter(0, 0, marker='*', c='g', zorder=10)
# Every 5th keypoint (starting at index 0): raw as dots, aligned as crosses.
ax.scatter(A[::5, 0], A[::5, 1], A[::5, 2], c=np.arange(A[::5].shape[0]), cmap=cmap)
ax.scatter(Z[::5, 0], Z[::5, 1], Z[::5, 2], marker='x', c=np.arange(Z[::5].shape[0]), cmap=cmap)

for n in range(fly_skel.shape[0]):
    # Raw skeleton in black, aligned skeleton in red.
    ax.plot(A[fly_skel[n], 0], A[fly_skel[n], 1], A[fly_skel[n], 2], c='k')
    ax.plot(Z[fly_skel[n], 0], Z[fly_skel[n], 1], Z[fly_skel[n], 2], c='r')

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.axis('equal')

# fig = plt.figure(figsize=(10, 7))
# ax = fig.add_subplot(111, projection='3d')
plot_plane_and_points(source_points.reshape(-1,3), source_normal, source_d, fig, ax, color='b', label='Source')
plot_plane_and_points(target_points.reshape(-1,3), target_normal, target_d, fig, ax, color='k', label='ground')
# Use the `transformed_*` names produced by the plane-fitting cell; the
# previous spelling (`tranformed_*`) did not match them and raised NameError.
# NOTE(review): this label duplicates the target plane's 'ground' label —
# confirm whether a distinct label was intended.
plot_plane_and_points(transformed_points.reshape(-1,3), transformed_normal, transformed_d, fig, ax, color='r', label='ground')

plt.tight_layout()
plt.show()

In [None]:
import matplotlib as mpl
def map_discrete_cbar(cmap, N):
    """Return a colormap resampled to ``N + 1`` colors and a matching norm.

    Args:
        cmap: Colormap name (or colormap object) accepted by ``plt.get_cmap``.
        N: Largest integer level; levels ``0..N`` each get their own color.

    Returns:
        ``(colormap, norm)`` where ``norm`` is a ``BoundaryNorm`` whose bin
        edges are ``-0.5, 0.5, ..., N + 0.5``.
    """
    stepped = plt.get_cmap(cmap, N + 1)
    # Half-integer edges put every integer level in the middle of its bin.
    boundaries = np.arange(-0.5, N + 1)
    stepped_norm = mpl.colors.BoundaryNorm(boundaries, stepped.N)
    return stepped, stepped_norm

# 2D projections of one raw frame (A) against the aligned frame (Z).
# frame_idx and clip_idx are assigned but not read in this cell, since the
# lines that used them are commented out.
frame_idx = 0
clip_idx = 1000
# test_bout = transform_bout(bout_dict['walking_bout{:04}'.format(clip_idx)]['orig_xpos'])
# Z = test_bout[0]# bout_dict['walking_bout{:04}'.format(clip_idx)]['orig_xpos'][frame_idx, :].reshape(30, 3) # (keypoint, xyz)
# Consecutive index pairs (n, n+1) for n = 0..28, regrouped into six chunks of
# four pairs each; pairs 4, 9, 14, 19 and 24 are left out.
# presumably each chunk is one 5-keypoint limb — TODO confirm
fly_skel = np.array([(n,n+1) for n in range(29)])
fly_skel = np.stack((fly_skel[:4],fly_skel[5:9],fly_skel[10:14],fly_skel[15:19],fly_skel[20:24],fly_skel[25:30],))

# `norm` is returned as well but is not passed to any plotting call in this cell.
cmap,norm = map_discrete_cbar('turbo',6)
fig, axs = plt.subplots(1,2,figsize=(12, 4))

# A = transform_frame(A)
# Z = transform_frame(Z)
t=10
# NOTE(review): A is not reloaded here — it is whatever an earlier cell left in
# the notebook namespace, merely re-centred on its first keypoint; it is not
# indexed by t the way Z is. Confirm the intended frame.
A = A - A[0]
Z = Z_all[t]
Z = Z - Z[0]
# Left panel: columns 0 and 1 (x vs y).
ax=axs[0]
# Green star marks the origin.
ax.scatter(0, 0, marker='*',c='g',zorder=10)
# Every 5th keypoint (starting at index 0): A as dots, Z as crosses.
ax.scatter(A[::5, 0], A[::5, 1], c=np.arange(A[::5].shape[0]),cmap=cmap)
ax.scatter(Z[::5, 0], Z[::5, 1], marker='x',c=np.arange(Z[::5].shape[0]),cmap=cmap)
for n in range(fly_skel.shape[0]):
    # A skeleton in black, Z skeleton in red.
    ax.plot(A[fly_skel[n], 0], A[fly_skel[n], 1],c='k')
    ax.plot(Z[fly_skel[n], 0], Z[fly_skel[n], 1],c='r')
ax.axis('equal')
ax.set_xlabel('x (cm)')
ax.set_ylabel('y (cm)')
ax.set_title('top-down view')

# Right panel: columns 0 and 2 (x vs z); the origin marker is disabled here.
ax = axs[1]
# ax.scatter(0, 0, marker='*',c='g',zorder=10)
ax.scatter(A[::5, 0], A[::5, 2], c=np.arange(A[::5].shape[0]),cmap=cmap)
ax.scatter(Z[::5, 0], Z[::5, 2], marker='x',c=np.arange(Z[::5].shape[0]),cmap=cmap)
for n in range(fly_skel.shape[0]):
    ax.plot(A[fly_skel[n], 0], A[fly_skel[n], 2],c='k')
    ax.plot(Z[fly_skel[n], 0], Z[fly_skel[n], 2],c='r')
ax.axis('equal')
ax.set_xlabel('x (cm)')
ax.set_ylabel('z (cm)')
ax.set_title('side view')
plt.show()

In [None]:
def transform_frame(frame):
    """Transform a single frame from data to model reference frame.

    The x and y columns are swapped and the new y column is negated (a
    quarter-turn about the z-axis), then every coordinate is multiplied by
    0.1 to convert millimetres to centimetres.

    Args:
        frame: Array-like with keypoints along axis 0 and the three xyz
            coordinates along axis 1. Integer arrays and nested lists are
            accepted and promoted to floating point.

    Returns:
        A new NumPy array with the same shape as ``frame``. Floating-point
        input keeps its dtype; the caller's array is never modified.
    """
    frame = np.asarray(frame)
    # The in-place scaling below cannot write fractional values back into an
    # integer array, so promote anything that is not already float/complex.
    if not np.issubdtype(frame.dtype, np.inexact):
        frame = frame.astype(float)
    # Rotate around z-axis. Fancy indexing returns a copy, so the in-place
    # operations that follow never touch the caller's data.
    frame = frame[:, [1, 0, 2]]
    frame[:, 1] *= -1
    # Change units mm to cm.
    frame *= 0.1
    return frame

In [None]:
import matplotlib as mpl
def map_discrete_cbar(cmap, N):
    """Return a colormap resampled to ``N + 1`` colors and a matching norm.

    Args:
        cmap: Colormap name (or colormap object) accepted by ``plt.get_cmap``.
        N: Largest integer level; levels ``0..N`` each get their own color.

    Returns:
        ``(colormap, norm)`` where ``norm`` is a ``BoundaryNorm`` whose bin
        edges are ``-0.5, 0.5, ..., N + 0.5``.
    """
    quantized = plt.get_cmap(cmap, N + 1)
    # Half-integer edges put every integer level in the middle of its bin.
    cut_points = np.arange(-0.5, N + 1)
    return quantized, mpl.colors.BoundaryNorm(cut_points, quantized.N)

# Plot the first frame of a transformed walking bout in two 2D projections.
# frame_idx is assigned but not read in this cell (index 0 is hard-coded below).
frame_idx = 0
clip_idx = 100
# transform_bout is defined elsewhere in the notebook.
test_bout = transform_bout(bout_dict['walking_bout{:04}'.format(clip_idx)]['orig_xpos'])
frame0 = test_bout[0]# bout_dict['walking_bout{:04}'.format(clip_idx)]['orig_xpos'][frame_idx, :].reshape(30, 3) # (keypoint, xyz)
# Consecutive index pairs (n, n+1) for n = 0..28, regrouped into six chunks of
# four pairs each; pairs 4, 9, 14, 19 and 24 are left out.
# presumably each chunk is one 5-keypoint limb — TODO confirm
fly_skel = np.array([(n,n+1) for n in range(29)])
fly_skel = np.stack((fly_skel[:4],fly_skel[5:9],fly_skel[10:14],fly_skel[15:19],fly_skel[20:24],fly_skel[25:30],))

# `norm` is returned as well but is not passed to any plotting call in this cell.
cmap,norm = map_discrete_cbar('turbo',6)
fig, axs = plt.subplots(1,2,figsize=(12, 4))

# Left panel: columns 0 and 1 (x vs y); all keypoints colored by index.
ax=axs[0]
ax.scatter(frame0[:, 0], frame0[:, 1], c=np.arange(frame0.shape[0]),cmap=cmap)
for n in range(fly_skel.shape[0]):
    # Each skeleton chunk gets its own color from the discrete colormap.
    ax.plot(frame0[fly_skel[n], 0], frame0[fly_skel[n], 1],c=cmap(n))
ax.axis('equal')
ax.set_xlabel('x (cm)')
ax.set_ylabel('y (cm)')
ax.set_title('top-down view')

# Right panel: columns 0 and 2 (x vs z).
ax = axs[1]
ax.scatter(frame0[:, 0], frame0[:, 2], c=np.arange(frame0.shape[0]),cmap=cmap)
for n in range(fly_skel.shape[0]):
    ax.plot(frame0[fly_skel[n], 0], frame0[fly_skel[n], 2],c=cmap(n))
ax.axis('equal')
ax.set_xlabel('x (cm)')
ax.set_ylabel('z (cm)')
ax.set_title('side view')
plt.show()

In [None]:
import matplotlib as mpl
def map_discrete_cbar(cmap, N):
    """Return a colormap resampled to ``N + 1`` colors and a matching norm.

    Args:
        cmap: Colormap name (or colormap object) accepted by ``plt.get_cmap``.
        N: Largest integer level; levels ``0..N`` each get their own color.

    Returns:
        ``(colormap, norm)`` where ``norm`` is a ``BoundaryNorm`` whose bin
        edges are ``-0.5, 0.5, ..., N + 0.5``.
    """
    banded = plt.get_cmap(cmap, N + 1)
    # Half-integer edges put every integer level in the middle of its bin.
    band_edges = np.arange(-0.5, N + 1)
    band_norm = mpl.colors.BoundaryNorm(band_edges, banded.N)
    return banded, band_norm

# Plot one frame of a walking bout, converted to the model reference frame
# by transform_frame, in two 2D projections.
frame_idx = 0
clip_idx = 2
# Reshape the selected frame to 30 rows of xyz before transforming it.
frame0 = transform_frame(bout_dict['walking_bout{:04}'.format(clip_idx)]['orig_xpos'][frame_idx, :].reshape(30, 3)) # (keypoint, xyz)
# Consecutive index pairs (n, n+1) for n = 0..28, regrouped into six chunks of
# four pairs each; pairs 4, 9, 14, 19 and 24 are left out.
# presumably each chunk is one 5-keypoint limb — TODO confirm
fly_skel = np.array([(n,n+1) for n in range(29)])
fly_skel = np.stack((fly_skel[:4],fly_skel[5:9],fly_skel[10:14],fly_skel[15:19],fly_skel[20:24],fly_skel[25:30],))

# `norm` is returned as well but is not passed to any plotting call in this cell.
cmap,norm = map_discrete_cbar('turbo',6)
fig, axs = plt.subplots(1,2,figsize=(12, 4))

# Left panel: columns 0 and 1 (x vs y); all keypoints colored by index.
ax=axs[0]
ax.scatter(frame0[:, 0], frame0[:, 1], c=np.arange(frame0.shape[0]),cmap=cmap)
for n in range(fly_skel.shape[0]):
    # Each skeleton chunk gets its own color from the discrete colormap.
    ax.plot(frame0[fly_skel[n], 0], frame0[fly_skel[n], 1],c=cmap(n))
ax.axis('equal')
# transform_frame has already scaled the data by 0.1 (mm -> cm), so the axes
# are labelled in cm rather than mm.
ax.set_xlabel('x (cm)')
ax.set_ylabel('y (cm)')
ax.set_title('top-down view')

# Right panel: columns 0 and 2 (x vs z).
ax = axs[1]
ax.scatter(frame0[:, 0], frame0[:, 2], c=np.arange(frame0.shape[0]),cmap=cmap)
for n in range(fly_skel.shape[0]):
    ax.plot(frame0[fly_skel[n], 0], frame0[fly_skel[n], 2],c=cmap(n))
ax.axis('equal')
ax.set_xlabel('x (cm)')
ax.set_ylabel('z (cm)')
ax.set_title('side view')
plt.show()

In [None]:
# Dot plot of component 0 against component 1 (last axis) for the first entry of kp_data.
plt.plot(kp_data[0,:,0],kp_data[0,:,1],'.')