In [1]:
import sys
sys.path.append('/home/diego/code/olveczky/dm/stac')
import compute_stac
import view_stac
import util
import stac
import rodent_environments
import numpy as np
import scipy.optimize
import pickle
from dm_control import viewer
from dm_control.mujoco.wrapper.mjbindings import mjlib

In [37]:
# Input snippet, animal parameter file and pre-computed marker offsets.
data_path = "/home/diego/data/dm/stac/snippets/JDM25_v5/reformatted/snippet_49_Walk.mat"
param_path = "/home/diego/code/olveczky/dm/stac/params/june3/JDM25.yaml"
offset_path = "/home/diego/data/dm/stac/offsets/JDM25_m_9_NoHip.p"

# Run-time overrides applied on top of the values loaded from the YAML file.
kw = dict(
    offset_path=offset_path,
    start_frame=0,
    n_frames=None,
    n_sample_frames=50,
    verbose=True,
    skip=2,
    adaptive_z_offset=True,
    visualize=False,
    render_video=False,
)
params = util.load_params(param_path)
for override_name, override_value in kw.items():
    params[override_name] = override_value

# Tighten the limb-optimization tolerance tenfold relative to the file value.
# (To loosen the full-body tolerance instead: params['_FTOL'] *= 10)
params['_LIMB_FTOL'] /= 10
params['_XML_PATH'] = "/home/diego/code/olveczky/dm/stac/models/rat_june3.xml"

# Load the raw snippet and preprocess it into keypoint data for the optimizer.
data, kp_names, behavior, com_vel = util.load_snippets_from_file(data_path)
kp_data = compute_stac.preprocess_snippet(data, kp_names, params)

In [39]:
# Debug buffers: every call to q_loss appends the full pose it evaluated and
# the keypoint reference it compared against. Later cells subsample these to
# replay the optimizer's trajectory. They grow without bound, so reset them
# before a run you want to inspect.
q_frames = []
kp_frames = []

def q_loss(q, physics, kp_data, sites, params, qs_to_opt=None, q_copy=None,
           reg_coef=0., root_only=False, temporal_regularization=False,
           q_prev=None, q_next=None):
    """Compute the marker loss for q_phase optimization.

    Side effect: appends the evaluated full qpos to the global ``q_frames`` and
    a copy of ``kp_data`` to the global ``kp_frames``.

    :param q: Candidate qpos from the optimizer (only the ``qs_to_opt``
              entries when a subset is being optimized).
    :param physics: Physics of current environment.
    :param kp_data: Keypoint reference for the current frame.
    :param sites: sites of keypoints at frame_index
    :param params: Animal parameters dictionary; ``'temporal_reg_coef'``
                   weights the temporal penalty.
    :param qs_to_opt: Index or mask selecting the qpos entries being optimized.
    :param q_copy: Copy of current qpos, for use in optimization of subsets
                   of qpos. Overwritten in place at ``qs_to_opt``.
    :param reg_coef: Weight of the sum-of-squares (L2) penalty on ``q[7:]``.
    :param root_only: If True, zero ``q[7:]`` (mutating the caller's ``q``)
                      so only the first 7 entries (the root) move the markers.
    :param temporal_regularization: If True, penalize squared differences of
                   the optimized entries from ``q_prev`` and ``q_next``.
    :param q_prev: Copy of previous qpos frame for use in
                   bidirectional temporal regularization.
    :param q_next: Copy of next qpos frame for use in bidirectional temporal
                   regularization.
    :return: The loss. It is a scalar normally. With temporal regularization,
             the per-entry penalty makes it an array with one value per
             optimized qpos entry.
    :raises ValueError: If temporal regularization is requested while
             ``qs_to_opt``, ``q_prev`` or ``q_next`` is None.
    """
    if temporal_regularization:
        # ValueError rather than `_TestNoneArgs`, which is not defined in this
        # notebook and would otherwise surface as a NameError.
        required = (('qs_to_opt', qs_to_opt), ('q_prev', q_prev),
                    ('q_next', q_next))
        for arg_name, arg_value in required:
            if arg_value is None:
                raise ValueError(
                    arg_name + ' cannot be None if using temporal regularization')

    # Optional L2 regularization.
    # NOTE(review): when qs_to_opt is given, `q` is the optimizer's subset
    # vector here, so `q[7:]` does not index the non-root joints (and is empty
    # for subsets of 7 or fewer entries) — confirm this is intended.
    reg_term = reg_coef * np.sum(q[7:]**2)

    # If only optimizing the root, set everything else to 0. This writes into
    # the `q` array passed by the caller.
    if root_only:
        q[7:] = 0.

    # If optimizing arbitrary sets of qpos, add the optimizer qpos to the copy.
    if qs_to_opt is not None:
        q_copy[qs_to_opt] = q
        q = np.copy(q_copy)

    # Record the evaluated pose and reference for later replay.
    q_frames.append(q.copy())
    kp_frames.append(kp_data.copy())

    # Add bidirectional temporal regularization for the optimized entries.
    temp_reg_term = 0.
    if temporal_regularization:
        temp_reg_term += (q[qs_to_opt] - q_prev[qs_to_opt])**2
        temp_reg_term += (q[qs_to_opt] - q_next[qs_to_opt])**2

    residual = (kp_data.T - stac.q_joints_to_markers(q, physics, sites))
    return (.5 * np.sum(residual**2) + reg_term +
            params['temporal_reg_coef'] * temp_reg_term)


def q_phase(physics, marker_ref_arr, sites, params, reg_coef=0.,
            qs_to_opt=None, root_only=False, temporal_regularization=False,
            q_prev=None, q_next=None):
    """Update q_pose using estimated marker parameters.

    Runs ``scipy.optimize.least_squares`` on ``q_loss`` and writes the result
    into ``physics.named.data.qpos``, then recomputes kinematics.

    :param physics: Physics of current environment.
    :param marker_ref_arr: Keypoint data reference
    :param sites: sites of keypoints at frame_index
    :param params: Animal parameters dictionary. Reads ``'_ROOT_FTOL'``,
                   ``'_LIMB_FTOL'``, ``'_FTOL'`` and ``'_DIFF_STEP'`` here, plus
                   whatever ``q_loss`` reads.
    :param reg_coef: Weight of the sum-of-squares (L2) penalty, forwarded to
                     ``q_loss``.
    :param qs_to_opt: Binary vector of qs to optimize.
    :param root_only: If True, only optimize the root.
    :param temporal_regularization: If True, regularize the optimized entries
                                    toward ``q_prev`` and ``q_next``.
    :param q_prev: Previous-frame qpos (used with temporal regularization).
    :param q_next: Next-frame qpos (used with temporal regularization).
    """
    # Bounds: the first 7 qpos (the root) are unbounded; the remaining entries
    # take their limits from jnt_range, skipping joint 0.
    lb = np.concatenate(
        [-np.inf * np.ones(7), physics.named.model.jnt_range[1:][:, 0]])
    # Make every lower bound admit 0 (presumably so the zeroed non-root qpos
    # used by root_only stays feasible).
    lb = np.minimum(lb, 0.0)
    ub = np.concatenate(
        [np.inf * np.ones(7), physics.named.model.jnt_range[1:][:, 1]])

    # Define initial position of the optimization
    q0 = np.copy(physics.named.data.qpos[:])
    q_copy = np.copy(q0)

    # Set the center to help with finding the optima
    # TODO(centering_bug):
    # The center is not necessarily from 12:15 depending on struct ordering.
    # This probably won't be a problem, as it is just an initialization for the
    # optimizer, but keep it in mind.
    if root_only:
        q0[:3] = marker_ref_arr[12:15]

    # If you only want to optimize a subset of qposes,
    # limit the optimizer to that
    if qs_to_opt is not None:
        q0 = q0[qs_to_opt]
        lb = lb[qs_to_opt]
        ub = ub[qs_to_opt]

    # least_squares rejects an x0 outside the bounds ("x0 is infeasible").
    # The current pose can violate them (e.g. when a joint's upper limit is
    # negative), so project it into the box. This is a no-op for feasible poses.
    q0 = np.clip(q0, lb, ub)

    # Use different tolerances for root vs normal optimization
    if root_only:
        ftol = params['_ROOT_FTOL']
    elif qs_to_opt is not None:
        ftol = params['_LIMB_FTOL']
    else:
        ftol = params['_FTOL']
    q_opt_param = scipy.optimize.least_squares(
        lambda q: q_loss(q, physics, marker_ref_arr, sites, params,
                         qs_to_opt=qs_to_opt,
                         q_copy=q_copy,
                         reg_coef=reg_coef,
                         root_only=root_only,
                         temporal_regularization=temporal_regularization,
                         q_prev=q_prev,
                         q_next=q_next),
        q0, bounds=(lb, ub), ftol=ftol, diff_step=params['_DIFF_STEP'],
        verbose=0)

    # Set pose to the optimized q and step forward.
    if qs_to_opt is None:
        physics.named.data.qpos[:] = q_opt_param.x
    else:
        # Merge the optimized subset back into the full pose.
        q_copy[qs_to_opt] = q_opt_param.x
        physics.named.data.qpos[:] = q_copy.copy()

    mjlib.mj_kinematics(physics.model.ptr, physics.data.ptr)


def q_clip_iso(env, params):
    """Run q_phase over every frame of the clip, then smooth the limbs.

    Pass 1 fits the whole body on each frame, then refines the right/left legs,
    right/left arms and head one at a time. Pass 2 revisits each interior frame
    and re-fits the arms and legs with a bidirectional temporal regularizer
    toward the neighbouring frames.

    :param env: Rodent environment.
    :param params: Rodent parameters.
    :return: ``(q, walker_body_sites)``: per-frame copies of qpos and of the
             body-site world positions.
    """
    physics = env.physics
    walker_sites = env.task._walker.body_sites
    keypoints = env.task.kp_data
    n_frames = params['n_frames']

    def part(*names):
        # Ids of the qpos entries belonging to the named body parts.
        return compute_stac._get_part_ids(env, list(names))

    r_leg = part('hip_R', 'knee_R')
    l_leg = part('hip_L', 'knee_L')
    r_arm = part('scapula_R', 'shoulder_R', 'elbow_R')
    l_arm = part('scapula_L', 'shoulder_L', 'elbow_L')
    head = part('atlas', 'cervical', 'atlant_extend')

    def fit(frame, **extra):
        # One q_phase call on `frame` with the shared regularization weight.
        q_phase(physics, keypoints[frame, :], walker_sites, params,
                reg_coef=params['q_reg_coef'], **extra)

    def site_positions():
        return np.copy(physics.bind(walker_sites).xpos[:])

    q = []
    walker_body_sites = []
    for frame in range(n_frames):
        print(frame)
        # Whole-body estimate first, then each part on its own.
        fit(frame)
        for limb in (r_leg, l_leg, r_arm, l_arm, head):
            fit(frame, qs_to_opt=limb)
        q.append(np.copy(physics.named.data.qpos[:]))
        walker_body_sites.append(site_positions())
    print(len(q))

    # Bidirectional temporal regularization over interior frames.
    for frame in range(1, n_frames - 1):
        physics.named.data.qpos[:] = q[frame]
        for limb in (r_arm, l_arm, r_leg, l_leg):
            fit(frame, qs_to_opt=limb, temporal_regularization=True,
                q_prev=q[frame - 1], q_next=q[frame + 1])
            q[frame][limb] = np.copy(physics.named.data.qpos[:][limb])
        walker_body_sites[frame] = site_positions()
    return q, walker_body_sites

def root_optimization(env, params):
    """Fit only the root pose to the first frame of keypoints.

    :param env: Rodent environment; its physics qpos is overwritten by q_phase.
    :param params: Animal parameters dictionary.
    """
    first_frame_kp = env.task.kp_data[0, :]
    q_phase(env.physics, first_frame_kp, env.task._walker.body_sites, params,
            root_only=True)
    
def render_stac_animation(kp_data, params):
    """Build the rodent environment and fit poses for the clip with q_clip_iso.

    Through q_loss, this also fills the global ``q_frames`` / ``kp_frames``
    debug buffers.

    :param kp_data: Keypoint array whose first axis indexes frames.
    :param params: Animal parameters. ``params['n_frames']`` is set to
                   ``kp_data.shape[0]`` when None and then cast to int; this
                   mutates the caller's dict.
    :return: ``(q, walker_body_sites)`` as returned by ``q_clip_iso``.
    """
    if params['n_frames'] is None:
        params['n_frames'] = kp_data.shape[0]
    params['n_frames'] = int(params['n_frames'])

    # Build the environment
    env = rodent_environments.rodent_mocap(kp_data, params)

    # If preloading offsets, set them now.
    # NOTE(review): root optimization only runs when offsets are preloaded —
    # confirm this is intended.
    if params['offset_path'] is not None:
        # pickle.load can execute arbitrary code; only load trusted files.
        with open(params['offset_path'], 'rb') as f:
            in_dict = pickle.load(f)

        sites = env.task._walker.body_sites
        env.physics.bind(sites).pos[:] = in_dict['offsets']
        # Mirror the bound positions back onto the site objects themselves.
        for site, pos in zip(sites, env.physics.bind(sites).pos):
            site.pos = pos

        if params['verbose']:
            print('Root Optimization', flush=True)
        root_optimization(env, params)

    # Q_phase optimization
    if params['verbose']:
        print('q-phase', flush=True)
    q, walker_body_sites = q_clip_iso(env, params)
    return q, walker_body_sites

    
def animate_qpos(kp_data, q, offset_path, save_path=None, render_video=False,
                 headless=False, params=None):
    """Replay precomputed qpos in a rodent environment.

    With ``headless`` and ``render_video`` both true, the physics is stepped
    manually until the environment's time limit. In every other case
    (including ``headless=True, render_video=False``), the interactive viewer
    is launched.

    :param kp_data: Keypoint array used to build the environment.
    :param q: Sequence of qpos to replay (stored as ``env.task.precomp_qpos``).
    :param offset_path: Pickle file whose ``'offsets'`` array holds one row per
                        body site.
    :param save_path: Optional video file name stored in ``env.task.video_name``.
    :param render_video: Forwarded to ``env.task.render_video``.
    :param headless: Step the physics without the viewer (needs ``render_video``).
    :param params: Animal parameters. If None, falls back to the notebook-level
                   global ``params``, matching the previous behaviour.
    """
    if params is None:
        # Backward compatibility: this function used to read the global directly.
        params = globals()['params']

    # pickle.load can execute arbitrary code; only load trusted files.
    with open(offset_path, 'rb') as f:
        offsets = pickle.load(f)['offsets']

    # Build the environment, and set the offsets, and params
    env = rodent_environments.rodent_mocap(kp_data, params)
    sites = env.task._walker.body_sites
    env.physics.bind(sites).pos[:] = offsets
    for idx, site in enumerate(sites):
        site.pos = offsets[idx, :]
    env.task.precomp_qpos = q
    env.task.render_video = render_video
    if save_path is not None:
        env.task.video_name = save_path
        print('Rendering: ', env.task.video_name)

    try:
        if headless and render_video:
            # Render a video in headless mode: advance the physics and call
            # after_step once per params['_TIME_BINS'] of simulated time.
            prev_time = env.physics.time()
            while prev_time < env._time_limit:
                while (env.physics.time() - prev_time) < params['_TIME_BINS']:
                    env.physics.step()
                env.task.after_step(env.physics, None)
                prev_time = env.physics.time()
        else:
            # Otherwise, use the viewer
            viewer.launch(env)
    finally:
        # Release env.task.V (presumably the video writer) even if rendering
        # raised.
        if env.task.V is not None:
            env.task.V.release()

In [47]:
# Reset the q_loss debug buffers, then fit a short 7-frame clip.
q_frames, kp_frames = [], []
params['n_frames'] = 7
render_stac_animation(kp_data, params);

Root Optimization
q-phase
0
1
2
3
4
5
6
7


In [15]:
# Keep every `fr`-th recorded optimizer evaluation among the first `lim`.
# (List slicing matches the modulo filter for positive lim and fr.)
lim = 15000
fr = 50
q = q_frames[:lim:fr]
kp = kp_frames[:lim:fr]

first_frame = kp_data[0, :].copy()
# Stack the kept keypoint references into a (len(q), n_keypoint_coords) array.
kp_frozen = np.zeros((len(q), kp_data.shape[1]))
for row, frame_kp in zip(kp_frozen, kp):
    row[:] = frame_kp
params['n_frames'] = len(q) - 1
print(kp_frozen.shape)
print(len(q_frames))
print(len(q_frames))

(300, 60)
50498


In [16]:
# `imp` is deprecated and was removed in Python 3.12; importlib.reload replaces it.
import importlib
importlib.reload(rodent_environments)
save_path = 'qphase_root_animation_indiv.mp4'
headless = True
# With render_video=False, animate_qpos launches the interactive viewer
# even though headless=True.
render_video = False
animate_qpos(kp_frozen, q, offset_path, headless=headless, render_video=render_video, save_path=save_path)

Rendering:  qphase_root_animation_indiv.mp4


In [68]:
# Keep every `fr`-th recorded optimizer evaluation with index strictly
# between `lb` and `ub`.
lb = 46000
ub = 52250
fr = 5
start = (lb // fr + 1) * fr  # first multiple of fr strictly greater than lb
q = q_frames[start:ub:fr]
kp = kp_frames[start:ub:fr]

first_frame = kp_data[0, :].copy()
# Stack the kept keypoint references into a (len(q), n_keypoint_coords) array.
kp_frozen = np.zeros((len(q), kp_data.shape[1]))
for row, frame_kp in zip(kp_frozen, kp):
    row[:] = frame_kp
params['n_frames'] = len(q) - 1
print(kp_frozen.shape)
print(len(q_frames))

(1259, 60)
88606


In [69]:
# `imp` is deprecated and was removed in Python 3.12; importlib.reload replaces it.
import importlib
importlib.reload(rodent_environments)
save_path = 'qphase_body_limbs_animation_slow_first_frame.mp4'
headless = True
# With render_video=False, animate_qpos launches the interactive viewer
# even though headless=True.
render_video = False
animate_qpos(kp_frozen, q, offset_path, headless=headless, render_video=render_video, save_path=save_path)

Rendering:  qphase_body_limbs_animation_slow_first_frame.mp4


In [None]:
# Keep every `fr`-th recorded optimizer evaluation with index strictly
# between `lb` and `ub`.
lb = 50000
ub = 200000
fr = 5
start = (lb // fr + 1) * fr  # first multiple of fr strictly greater than lb
q = q_frames[start:ub:fr]
kp = kp_frames[start:ub:fr]

first_frame = kp_data[0, :].copy()
# Stack the kept keypoint references into a (len(q), n_keypoint_coords) array.
kp_frozen = np.zeros((len(q), kp_data.shape[1]))
for row, frame_kp in zip(kp_frozen, kp):
    row[:] = frame_kp
params['n_frames'] = len(q) - 1
print(kp_frozen.shape)
print(len(q_frames))

In [None]:
# `imp` is deprecated and was removed in Python 3.12; importlib.reload replaces it.
import importlib
importlib.reload(rodent_environments)
# NOTE(review): same file name as the earlier window cell; with
# render_video=True this would overwrite that video — rename if both are wanted.
save_path = 'qphase_body_limbs_animation_slow_first_frame.mp4'
headless = True
# With render_video=False, animate_qpos launches the interactive viewer
# even though headless=True.
render_video = False
animate_qpos(kp_frozen, q, offset_path, headless=headless, render_video=render_video, save_path=save_path)