# Inverse kinematics fit from the fly hackathon

In [None]:
from os import path as op
import sys
# sys.path.insert(0, op.abspath('/Users/eabe/Research/MyRepos/BiomechControl/models/fruitfly_v2/'))
sys.path.insert(0, op.abspath('/Users/eabe/Research/Github/flybody'))

from pathlib import Path
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import PIL.ImageDraw
from tqdm import tqdm

from dm_control import mujoco
from dm_control import mjcf
from dm_control.mujoco.wrapper.mjbindings import enums
from dm_control.mujoco.wrapper import mjbindings

import sys
sys.path.append('..')

from flybody.inverse_kinematics import qpos_from_site_xpos
from flybody.utils import display_video


In [None]:
def transform_frame(frame, scale=0.1):
    """Transform keypoints from the data reference frame to the model reference frame.

    Applies a rotation about the z-axis (new x = old y, new y = -old x,
    z unchanged) followed by a uniform unit conversion (default 0.1,
    i.e. mm -> cm as used by this notebook).

    Args:
        frame: Array-like whose last axis holds (x, y, z), e.g. one frame of
            shape (n_keypoints, 3) or a stack of shape (n_frames, n_keypoints, 3).
            The input is never modified.
        scale: Multiplicative unit conversion applied after the rotation.
            Defaults to 0.1 (mm -> cm).

    Returns:
        A new floating-point array with the same shape as `frame`
        (float input keeps its dtype; integer input is promoted to float64).
    """
    # Fancy indexing on the last axis returns a copy, so the caller's array is untouched.
    out = np.asarray(frame)[..., [1, 0, 2]]
    # In-place float scaling below would raise on integer arrays; promote them first.
    if not np.issubdtype(out.dtype, np.floating):
        out = out.astype(np.float64)
    # Rotate around z-axis: swap x/y, then negate the new y (a proper rotation, det = +1).
    out[..., 1] *= -1
    # Change units (mm to cm by default).
    out *= scale
    return out

# Load a single walking sequence

In [None]:
# Machine-specific locations of the walking dataset and the MJPC tracking task.
base_path = Path('/Users/eabe/Research/MyRepos/mujoco_mpc/Archive/fruitfly')
task_path = Path('/Users/eabe/Research/MyRepos/mujoco_mpc/mjpc/tasks/fruitfly/flytrackingqpos')

data_path = base_path / 'combined_wt_berlin_walking_v3.pq'
full_df = pd.read_parquet(data_path, engine='pyarrow')

# Per-bout summary statistics of the two FicTrac signals used for selection
# (units presumably mm/s and deg/s, per the column names).
bout_stats = (
    full_df
    .groupby(['walking_bout_number', 'fullfile', 'Sex'])
    [['fictrac_delta_rot_lab_y_mms', 'fictrac_delta_rot_lab_z_deg/s']]
    .agg(['mean', 'min', 'max', 'std', 'count'])
)

# "Fast" bouts: high mean forward speed that never drops too low.
fast = (
    (bout_stats[('fictrac_delta_rot_lab_y_mms', 'mean')] >= 12)
    & (bout_stats[('fictrac_delta_rot_lab_y_mms', 'min')] >= 10)
)
# "Straight" bouts: small mean turning rate and bounded extremes.
straight = (
    (bout_stats[('fictrac_delta_rot_lab_z_deg/s', 'mean')].abs() <= 45)
    & (bout_stats[('fictrac_delta_rot_lab_z_deg/s', 'min')] >= -60)
    & (bout_stats[('fictrac_delta_rot_lab_z_deg/s', 'max')] <= 60)
)

# Optional minimum-length filter (disabled):
# minlen = bout_stats[('fictrac_delta_rot_lab_z_deg/s', 'count')] >= 300

# Show the candidate bouts, ordered by their minimum forward speed.
bout_stats[fast & straight].sort_values(('fictrac_delta_rot_lab_y_mms', 'min'))

In [None]:
# Candidate bout #1: dataset bout id 14574.
walking_bout_n = 1
df = full_df.loc[full_df['walking_bout_number'] == 14574]
df.shape

In [None]:
# Uncomment to list every available bout id:
# full_df['walking_bout_number'].unique()
# Candidate bout #2: dataset bout id 2244 (overrides the previous selection when run after it).
walking_bout_n = 2
df = full_df.loc[full_df['walking_bout_number'] == 2244]
df.shape

In [None]:
# Print every column name of the selected bout, one per line.
for column_name in df.columns:
    print(column_name)

In [None]:
# joint_names_df = ['{}{}_{}'.format(leg, joint, axis) for leg in legs_data for joint in ['A', 'B', 'C', 'D'] for axis in ['abduct', 'flex', 'rot'] if '{}{}_{}'.format(leg, joint, axis) in df.columns]
# joint_names_mujoco = ['{}{}_{}'.format(leg, joint, axis) for leg in legs for joint in ['A', 'B', 'C', 'D'] for axis in ['abduct', 'flex', 'rot'] if '{}{}_{}'.format(leg, joint, axis) in df.columns]

In [None]:
# Leg joints the IK solver may adjust: seven DOFs per leg, legs ordered T1..T3 with left before right.
joints_to_manipulate = [
    f'{dof}_{leg}'
    for leg in ('T1_left', 'T1_right', 'T2_left', 'T2_right', 'T3_left', 'T3_right')
    for dof in ('coxa_abduct', 'coxa_twist', 'coxa', 'femur', 'femur_twist', 'tibia', 'tarsus')
]

In [None]:
# Diagnostic: overlay one fitted joint angle on the matching data angle (radians, mean removed).
# NOTE(review): needs `qpos` (from the "Try fitting all frames" cell) and `joint_names_df`, which is
# only defined in a commented-out cell — run/define those first.
# NOTE(review): column n + 6 assumes 6 leading non-leg qpos entries, while elsewhere this notebook
# skips 7 (qpos[:, 7:]); confirm the offset and that joint_names_df matches the model's joint order.
n = 3
plt.plot(qpos[:, n + 6])
data_angle_rad = np.deg2rad(df[joint_names_df[n]].values)
plt.plot(data_angle_rad - data_angle_rad.mean())

# Load the fly model

In [None]:
# Parse the task MJCF (an alternative model file is kept commented out).
xml_path = task_path.joinpath('task.xml')
# xml_path = task_path.parent / 'fruitfly_force.xml'
mjcf_model = mjcf.from_path(xml_path)

In [None]:
# Visualize the fly as loaded: fresh Physics, reset to its default state, rendered from camera 1.
physics = mjcf.Physics.from_mjcf_model(mjcf_model)
# physics.model.geom('floor').pos = (0, 0, -.137)
_ = physics.reset()
# Optional pose tweaks before rendering (disabled):
# retract_wings(physics, prefix='')
# physics.step()
pixels = physics.render(width=640, height=480, camera_id=1)
PIL.Image.fromarray(pixels)

In [None]:
# Model-side names: legs T1..T3 (left before right) and the parts tracked on each leg.
legs = ['T1_left', 'T1_right', 'T2_left', 'T2_right', 'T3_left', 'T3_right']
joints = ['coxa', 'femur', 'tibia', 'tarsus']
xpos_geoms = ['coxa', 'femur', 'tibia', 'tarsus', 'claw']
joint_names = [f'{part}_{leg_name}' for leg_name in legs for part in joints]
xpos_names = [f'{part}_{leg_name}' for leg_name in legs for part in xpos_geoms]
# physics.named.data.framepos[pos_names]
# Tracking sites are named 'tracking[<part>_<leg>]', one per entry of xpos_names.
site_names = ['tracking[' + body_name + ']' for body_name in xpos_names]

# Data-side names: keypoint columns '<leg><point><_axis>', ordered leg -> point -> axis (e.g. 'L1A_x').
legs_data = ['L1', 'R1', 'L2', 'R2', 'L3', 'R3']
joints_data = ['A', 'B', 'C', 'D', 'E']
coords_data = ['_x', '_y', '_z']
joint_pos_columns = [
    f'{leg_id}{point_id}{axis_suffix}'
    for leg_id in legs_data
    for point_id in joints_data
    for axis_suffix in coords_data
]

# Generate joint sequences in model and data in matching order

In [None]:
# Model side (disabled):
# mjcf_model = mjcf.from_path(xml_path)
# legs = ['T1_left', 'T1_right', 'T2_left', 'T2_right', 'T3_left', 'T3_right']
# joints = ['coxa', 'femur', 'tibia', 'tarsus', 'claw']
# joint_names = [f'{joint}_{leg}' for leg in legs for joint in joints]

# Data side: keypoint columns '<leg><point><_axis>', ordered leg -> point -> axis (e.g. 'L1A_x').
legs_data = ['L1', 'R1', 'L2', 'R2', 'L3', 'R3']
joints_data = ['A', 'B', 'C', 'D', 'E']
coords_data = ['_x', '_y', '_z']
joint_pos_columns = [
    f'{leg_id}{point_id}{axis_suffix}'
    for leg_id in legs_data
    for point_id in joints_data
    for axis_suffix in coords_data
]

In [None]:
# Sanity check that every leg joint in `joint_names` exists in the MJCF model.
# mjcf_model.find returns None for unknown names, so report those instead of discarding the result.
missing_joints = [name for name in joint_names if mjcf_model.find('joint', name) is None]
if missing_joints:
    print(f'Joints not found in model: {missing_joints}')


# Visualize a single frame

In [None]:
frame_idx = 0
# Raw keypoints of one frame as (keypoint, xyz), in the data's units (mm).
test_frame = df[joint_pos_columns].values[frame_idx, :].reshape(30, 3)

# Two orthogonal views of the same frame; keypoints colored by index.
plt.figure(figsize=(12, 4))
for panel, (vertical_axis, vertical_label, view_title) in enumerate(
    [(1, 'y (mm)', 'top-down view'), (2, 'z (mm)', 'side view')], start=1
):
    plt.subplot(1, 2, panel)
    plt.scatter(test_frame[:, 0], test_frame[:, vertical_axis], c=np.arange(30))
    plt.axis('equal')
    plt.xlabel('x (mm)')
    plt.ylabel(vertical_label)
    plt.title(view_title)

# Compare model's initial position with first data frame

In [None]:
# First data frame as (keypoint, xyz), converted into the model's frame and units.
frame0 = transform_frame(df[joint_pos_columns].values[0, :].reshape(30, 3))
# Model tracking-site positions in the current pose.
site_pos = physics.named.data.site_xpos[site_names]
# Keypoint 0 (body-coxa T1 left joint) is the data origin; translate the data onto that model site.
frame0 += site_pos[0, :]

In [None]:
# Top-down overlay before fitting: data keypoints (dots) vs. model tracking sites (crosses),
# both colored by keypoint index.
for points, series_label, extra_style in (
    (frame0, 'data', {}),
    (site_pos, 'model', {'marker': 'x', 's': 100}),
):
    plt.scatter(points[:, 0], points[:, 1], c=np.arange(30), label=series_label, **extra_style)
plt.axis('equal')
plt.title('initial site positions before fitting')
plt.xlabel('x (cm)')
plt.ylabel('y (cm)')
plt.legend()
plt.grid()

# Try fitting one frame

In [None]:
# Let IK move every joint defined in the MJCF model (replaces the leg-only list).
joints_to_manipulate = [joint_element.name for joint_element in mjcf_model.find_all('joint')]

In [None]:
# Sanity check before fitting: target keypoint array shape, number of IK joints, number of tracking sites.
frame0.shape, len(joints_to_manipulate), len(site_names)

In [None]:
# Start from the model's default state, then solve IK for the first data frame only.
_ = physics.reset()
# retract_wings(physics, prefix='')
res = qpos_from_site_xpos(
    physics,
    site_names,
    frame0,
    joints_to_manipulate,
    inplace=True,  # keep the solution in `physics` so the next cells can inspect the fitted pose
)

In [None]:
# Generalized coordinates returned by the single-frame IK solve.
res.qpos

In [None]:
# Plot the fitted qpos vector of the single test frame (x axis = qpos index).
plt.plot(res.qpos)


In [None]:
# Top-down overlay after the single-frame fit: data keypoints (dots) vs. fitted model sites (crosses).
site_pos = physics.named.data.site_xpos[site_names]
for points, series_label, extra_style in (
    (frame0, 'data', {}),
    (site_pos, 'fitted model', {'marker': 'x', 's': 100}),
):
    plt.scatter(points[:, 0], points[:, 1], c=np.arange(30), label=series_label, **extra_style)
plt.axis('equal')
plt.title('site positions after test-fitting one frame')
plt.xlabel('x (cm)')
plt.ylabel('y (cm)')
plt.legend()
plt.grid()

In [None]:
# #add ball
# texture = mjcf_model.asset.add('texture', rgb1=[.2, .3, .4], rgb2=[.1, .2, .3],
#                                type='2d', builtin='checker', name='groundplane',
#                                width=200, height=200,)
# material = mjcf_model.asset.add('material', name='ballsurface', 
#                                 texrepeat=[2, 2],  # Makes white squares exactly 1x1 length units.
#                                 texuniform=True,
#                                 reflectance=0.2,
#                                 texture=texture)

# # Remove freejoint.
# freejoint = mjcf_model.find('joint', 'free')
# if freejoint is not None:
#     freejoint.remove()
    
# # == Add ball.
# radius = 0.454  # Pick ball radius.
# # Calculate ball position wrt fly given ball radius.
# claw_T1_left = np.array([0.09178167, 0.08813114, -0.12480448])
# ball_x = - 0.05
# ball_z = - np.sqrt(radius**2 
#                    - (claw_T1_left[0] - ball_x)**2 
#                    - claw_T1_left[1]**2) + claw_T1_left[2]
# ball = mjcf_model.worldbody.add('body', name='ball', pos=(ball_x, 0, ball_z))
# ball.add('geom', type='sphere', size=(radius, 0, 0),
#          material=material, density=0.1)  # Density of water in cgs == 1.
# ball_joint = ball.add('joint', name='ball', type='ball')

# # Exclude "surprising collisions".
# for child in mjcf_model.find('body', 'thorax').all_children():
#     if child.tag == 'body':
#         mjcf_model.contact.add('exclude', name=f'thorax_{child.name}',
#                                body1='thorax', body2=child.name)

In [None]:
# Advance the simulation by 99 steps from its current state.
for _ in range(99):
    physics.step()

In [None]:
# physics = mjcf.Physics.from_mjcf_model(mjcf_model)
# physics.model.geom('floor').pos = (0, 0, -.137)
# NOTE: reset() returns the simulation to its default state, discarding the IK fit and the
# stepping above, so this renders the reset pose, not the model fitted to the first frame.
_ = physics.reset()
# Render with contact points and contact forces visualized (enum access kept consistent via `enums`).
scene_option = mujoco.wrapper.core.MjvOption()
scene_option.flags[enums.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[enums.mjtVisFlag.mjVIS_CONTACTFORCE] = True
pixels = physics.render(camera_id=0, width=640, height=480, scene_option=scene_option)
PIL.Image.fromarray(pixels)

# Try fitting all frames

In [None]:
# Fit every frame of the selected bout with IK, starting from the model's default state.
_ = physics.reset()
# retract_wings(physics, prefix='')

# Tracking-site positions in the reset pose; site 0 (body-coxa T1 left) anchors the data origin.
# Copied explicitly so the anchor cannot drift while IK updates physics.data in place below.
site_pos = physics.named.data.site_xpos[site_names].copy()
# Restrict IK to the seven leg DOFs of each of the six legs.
joints_to_manipulate = [
    f'{dof}_{leg}'
    for leg in ('T1_left', 'T1_right', 'T2_left', 'T2_right', 'T3_left', 'T3_right')
    for dof in ('coxa_abduct', 'coxa_twist', 'coxa', 'femur', 'femur_twist', 'tibia', 'tarsus')
]

# Extract the keypoint table once as (frame, keypoint, xyz) instead of re-selecting
# the DataFrame columns on every iteration.
keypoints_data = df[joint_pos_columns].values
n_frames = len(keypoints_data)
keypoints_data = keypoints_data.reshape(n_frames, -1, 3)

qpos = []
all_frames = []
for i in tqdm(range(n_frames)):
    # Prepare frame for fitting: data -> model frame/units, then shift onto the model anchor.
    frame = transform_frame(keypoints_data[i])
    frame += site_pos[0, :]
    all_frames.append(frame)
    # physics is not reset between frames, so each fit presumably starts from the previous solution.
    res = qpos_from_site_xpos(physics, site_names, frame, joints_to_manipulate, inplace=True)
    # Store a snapshot of the fitted qpos.
    qpos.append(res.qpos.copy())
all_frames = np.stack(all_frames)
qpos = np.stack(qpos)


In [None]:
# Joint trajectories of the all-frames fit, skipping the first 7 qpos entries
# (treated throughout this notebook as root/free-joint coordinates).
plt.plot(qpos[..., 7:])
plt.show()

In [None]:
# Number of joints passed to IK in the most recent fit.
len(joints_to_manipulate)

In [None]:
# qpos_stand = qpos[0,:] #np.array([0,0,0,0,0,0,0,0, 0, 0.000178521, -1.07906e-05, 0, 0.00022365, 0.00012099, 0, 0, 0, 0, 0, 0, -8.21227e-05, 1.7515e-05, 0, -6.83271e-05, -1.73481e-05, 0, 0, 0, 0, 0, 0, -7.29514e-05, 8.13461e-05, 0, -4.55924e-05, 1.73673e-05, 0, 0, 0, 0, 0, 0, -0.000200893, 0.000243786, 0, -9.05341e-05, -1.9564e-05, 0, 0, 0, 0, 0, 0, 0.000307064, -0.000344922, 0, 2.74459e-05, 5.30669e-06, 0, 0, 0, 0, 0, 0, 0.000162806, -7.42181e-05, 0, 0.000232139, 7.99452e-05, 0, 0, 0, 0])
# qvel_stand = np.zeros((qpos_stand.shape[0]-1))
# Finite-difference velocities of the fitted trajectory, laid out like physics.data.qvel (nv columns).
# A plain np.diff on qpos misaligns the free-joint columns: its quaternion uses 4 qpos entries but
# only 3 qvel entries. mj_differentiatePos performs the correct qpos -> qvel mapping for all joint types.
# dt=1.0 keeps per-frame units (change per frame); divide by the frame period for physical velocities.
# Row 0 stays zero, matching the previous prepend-first-frame convention.
qvel = np.zeros((qpos.shape[0], physics.model.nv))
for i in range(1, qpos.shape[0]):
    mjbindings.mjlib.mj_differentiatePos(physics.model.ptr, qvel[i], 1.0, qpos[i - 1], qpos[i])

In [None]:
# Resample keypoints, qpos and qvel onto a ~10x denser time grid by per-column linear interpolation.
# NOTE(review): both grids are linspace over [0, N], so samples do not sit on integer frame indices
# and the density ratio is only approximately 10x; use np.arange(N) if exact frame alignment matters.
interp_time = np.linspace(0, all_frames.shape[0], 10 * all_frames.shape[0])
time = np.linspace(0, all_frames.shape[0], all_frames.shape[0])
mod_frames, qpos_frames, qvel_frames = (
    np.apply_along_axis(lambda series: np.interp(interp_time, time, series), 0, source)
    for source in (all_frames, qpos, qvel)
)


In [None]:
# Use the fitted (non-resampled) trajectories for export, replacing the interpolated ones.
# Copies, not aliases: the next cell edits mod_frames in place, which would otherwise write through
# to all_frames. Copying also means re-running this cell restores the unmodified frames.
mod_frames = all_frames.copy()
qpos_frames = qpos.copy()
qvel_frames = qvel.copy()

In [None]:
# Shift every keypoint's z (last coordinate) by a fixed offset and clamp it to [z_min, z_max].
z_offset, z_min, z_max = 0.021, -0.135, 10
# Copy first so the edit never writes through to all_frames when mod_frames aliases it.
mod_frames = mod_frames.copy()
mod_frames[:, :, -1] = np.clip(mod_frames[:, :, -1] + z_offset, z_min, z_max)
print(np.min(mod_frames[:, :, -1]), np.max(mod_frames[:, :, -1]))

In [None]:
# Export one keyframe per frame with mocap targets (mpos) and fitted joint positions (qpos).
# Every frame is exported; key names are 1-based (`_1` is frame 0). Only the first key carries qvel.
root = mjcf.RootElement()
for frame_idx in tqdm(range(mod_frames.shape[0])):
    key_attrs = dict(
        name=f'walk{walking_bout_n}_{frame_idx + 1}',
        mpos=mod_frames[frame_idx].flatten(),
        qpos=qpos_frames[frame_idx].flatten(),
    )
    if frame_idx == 0:
        key_attrs['qvel'] = qvel_frames[0].flatten()
    root.keyframe.add('key', **key_attrs)

# Save the XML file and display the path that was actually written.
keyframe_path = task_path / 'keyframes'
keyframe_file = f'Fly_walk_free_qpos{walking_bout_n}.xml'
mjcf.export_with_assets(root, keyframe_path, keyframe_file)
keyframe_path / keyframe_file

In [None]:
# Export the mocap-only variant: every key carries mocap targets (mpos); the first key also carries
# the initial joint state. Key names are 1-based (`_1` is frame 0) and every frame is exported.
root = mjcf.RootElement()
root.keyframe.add(
    'key',
    name=f'walk{walking_bout_n}_1',
    mpos=mod_frames[0].flatten(),
    qpos=qpos[0].flatten(),
    # A key stores a single state, so its qvel needs nv values: use the first frame's row,
    # not the whole (n_frames, nv) array flattened.
    qvel=qvel[0].flatten(),
)
for frame_idx in range(1, mod_frames.shape[0]):
    root.keyframe.add('key', name=f'walk{walking_bout_n}_{frame_idx + 1}', mpos=mod_frames[frame_idx].flatten())

# Save the XML file and display the path that was actually written.
keyframe_path = task_path / 'keyframes'
keyframe_file = f'Fly_walk_free_pos{walking_bout_n}.xml'
mjcf.export_with_assets(root, keyframe_path, keyframe_file)
keyframe_path / keyframe_file

In [None]:
# Render frame 0: set mocap targets and joint state, advance 10 steps, then render from camera 1.
_ = physics.reset()
t = 0
scene_option.sitegroup[:] = [1] * 6  # show all six site groups
physics.data.mocap_pos = mod_frames[t]
physics.data.qpos = qpos[t]
for _settle_step in range(10):
    physics.step()
pixels = physics.render(camera_id=1, width=640, height=480, scene_option=scene_option)
# qpos_stand = physics.named.data.qpos
# qvel_stand = physics.named.data.qvel
PIL.Image.fromarray(pixels)

In [None]:
# Print small floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)
# print(["{:.02f}".format(qqpos) for qqpos in physics.data.qpos])

In [None]:
# Generate video of fitted poses.

camera_id = 1  # Side view.

# Show site groups 0-2 only (groups 3-5 hidden) and visualize contact points/forces.
scene_option = mujoco.wrapper.core.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 0, 0, 0]
scene_option.flags[enums.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[enums.mjtVisFlag.mjVIS_CONTACTFORCE] = True

#retract_wings(physics)

video_frames = []
for q in tqdm(qpos):
    # Assign each fitted pose inside reset_context so derived quantities are recomputed
    # (per dm_control's Physics.reset_context) before rendering.
    with physics.reset_context():
        physics.data.qpos = q.copy()
    pixels = physics.render(camera_id=camera_id, width=640, height=480, scene_option=scene_option)
    video_frames.append(pixels)

In [None]:
# Play back the rendered frames at 30 frames per second.
display_video(video_frames, framerate=30)

In [None]:
# Visualize fly standing: zero every qpos entry, record tracked body positions in that pose as
# mocap targets, take one physics step and render.
physics.reset()

physics.data.qpos = np.zeros(qpos.shape[1])
# NOTE(review): this also zeroes the free-joint quaternion (if qpos[3:7] is the free joint), which is
# degenerate; consider setting qpos[3] = 1 explicitly.
# Assigning qpos does not update derived quantities such as xpos; recompute kinematics so the
# positions read below belong to the zeroed pose rather than the reset pose.
physics.forward()

# Tracked body positions in the standing pose, shape (1, len(xpos_names), 3).
mod_frames_stand = np.stack([physics.named.data.xpos[name] for name in xpos_names])[None, :, :]
physics.data.mocap_pos = mod_frames_stand[0]

# Show all site groups and visualize contact points/forces.
scene_option = mujoco.wrapper.core.MjvOption()
scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 1]
scene_option.flags[enums.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[enums.mjtVisFlag.mjVIS_CONTACTFORCE] = True
physics.step()

pixels = physics.render(camera_id=1, width=640, height=480, scene_option=scene_option)
PIL.Image.fromarray(pixels)

In [None]:
# Export a standing keyframe sequence: the same standing pose repeated as keys _1 .. _99.
# Snapshot the state: physics.named.data.qpos indexes the live data, which later steps would change.
qpos_stand = physics.data.qpos.copy()
# Zero velocity sized from the model instead of assuming nv == nq - 1.
qvel_stand = np.zeros(physics.model.nv)
root = mjcf.RootElement()
root.keyframe.add('key', name=f'stand{walking_bout_n}_1', mpos=mod_frames_stand[0].flatten(), qpos=qpos_stand.flatten(), qvel=qvel_stand.flatten())
for n in range(2, 100):
    root.keyframe.add('key', name=f'stand{walking_bout_n}_{n}', mpos=mod_frames_stand[0].flatten(), qpos=qpos_stand.flatten())

# Save the XML file (without the <default> element) and display the path that was actually written.
root.default.remove()
keyframe_path = task_path / 'keyframes'
keyframe_file = f'Fly_stand_free_qpos{walking_bout_n}.xml'
mjcf.export_with_assets(root, keyframe_path, keyframe_file)
keyframe_path / keyframe_file

In [None]:
# Top-down check: every body position after stepping (dots) vs. the recorded standing mocap targets (crosses).
plt.plot(physics.data.xpos[:, 0], physics.data.xpos[:, 1], marker='o', linestyle='None')
plt.plot(mod_frames_stand[0, :, 0], mod_frames_stand[0, :, 1], marker='x', linestyle='None')