<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Detections-only" data-toc-modified-id="Detections-only-1">Detections only</a></span></li><li><span><a href="#GIMBAL-only" data-toc-modified-id="GIMBAL-only-2">GIMBAL only</a></span></li><li><span><a href="#Stages-videos" data-toc-modified-id="Stages-videos-3">Stages videos</a></span></li></ul></div>

In [9]:
import jax
import sys
import glob
# import gimbal
import joblib
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import joblib, json, os, h5py
from os.path import join, exists
import imageio, cv2
import tqdm.auto as tqdm
import matplotlib.pyplot as plt
import multicam_calibration as mcc
import keypoint_moseq as kpms
from scipy.ndimage import median_filter, gaussian_filter1d
from vidio.read import OpenCVReader

In [2]:
import networkx as nx
def build_node_hierarchy(bodyparts, skeleton, root_node):
    """
    Define a rooted hierarchy based on the edges of a spanning tree.

    Parameters
    ----------
    bodyparts: list of str
        Ordered list of node names.

    skeleton: list of tuples
        Edges of the spanning tree as pairs of node names.

    root_node: str
        The desired root node of the hierarchy

    Returns
    -------
    node_order: array of shape (num_nodes,)
        Integer array of indexes into `bodyparts` specifying an ordering of
        nodes in which parents precede children (i.e. a topological
        ordering). The root node comes first.

    parents: array of shape (num_nodes,)
        Child-parent relationships using the indexes from `node_order`,
        such that `parents[i]==j` when `node_order[j]` is the parent of
        `node_order[i]`. The root is assigned itself as parent
        (`parents[0]==0`).

    Raises
    ------
    ValueError
        `skeleton` references a node that is not in `bodyparts`, or the
        edges in `skeleton` do not define a spanning tree.
    """
    # Reject edges to unknown nodes up front; otherwise they would silently
    # be added to the graph and only fail later inside `bodyparts.index`.
    unknown = {name for edge in skeleton for name in edge} - set(bodyparts)
    if unknown:
        raise ValueError(
            'The skeleton references nodes that are not in bodyparts: '
            '{}'.format(sorted(unknown)))

    G = nx.Graph()
    G.add_nodes_from(bodyparts)
    G.add_edges_from(skeleton)

    # Connectivity must be checked first: `nx.is_tree` is also False for
    # acyclic graphs with several components (forests), so testing it first
    # would report an empty list of cycles and never reach this error.
    if not nx.is_connected(G):
        raise ValueError(
            'The skeleton does not define a spanning tree, '
            'as it contains multiple connected components.')

    if not nx.is_tree(G):
        cycles = list(nx.cycle_basis(G))
        raise ValueError(
            'The skeleton does not define a spanning tree, '
            'as it contains the following cycles: {}'.format(cycles))

    node_order = list(nx.dfs_preorder_nodes(G, root_node))
    # Position of every node in the DFS order (names are unique graph nodes),
    # avoiding a linear `list.index` scan for each edge endpoint.
    position = {name: k for k, name in enumerate(node_order)}
    parents = np.zeros(len(node_order), dtype=int)

    for a, b in skeleton:
        i, j = position[a], position[b]
        # In a DFS preorder of a tree, a parent is visited before its child,
        # so the endpoint that appears earlier is the parent.
        if i < j:
            parents[j] = i
        else:
            parents[i] = j

    node_order = np.array([bodyparts.index(n) for n in node_order])
    return node_order, parents

In [5]:
import yaml
# from keypoint_sort.util import build_node_hierarchy

# All keypoint names, in the order used to index the keypoint axis of the
# detection / triangulation arrays loaded later in this notebook.
bodyparts = [
    'tail_tip', 'tail_base',
    'spine_low', 'spine_mid', 'spine_high',
    'left_ear', 'right_ear', 'forehead', 'nose_tip',
    'left_hind_paw_front', 'left_hind_paw_back',
    'right_hind_paw_front', 'right_hind_paw_back',
    'left_fore_paw', 'right_fore_paw',
]

# Undirected skeleton edges (pairs of names). 'tail_tip' appears in neither
# the skeleton nor `use_bodyparts`.
skeleton = [
    # axial chain
    ['tail_base', 'spine_low'],
    ['spine_low', 'spine_mid'],
    ['spine_mid', 'spine_high'],
    # head
    ['spine_high', 'left_ear'],
    ['spine_high', 'right_ear'],
    ['spine_high', 'forehead'],
    ['forehead', 'nose_tip'],
    # hind limbs
    ['left_hind_paw_back', 'left_hind_paw_front'],
    ['spine_low', 'left_hind_paw_back'],
    ['right_hind_paw_back', 'right_hind_paw_front'],
    ['spine_low', 'right_hind_paw_back'],
    # fore limbs
    ['spine_high', 'left_fore_paw'],
    ['spine_high', 'right_fore_paw'],
]

# Keypoints kept for analysis: everything after the first entry (tail_tip),
# plus their positions within `bodyparts`.
use_bodyparts = bodyparts[1:]
use_bodyparts_ix = np.array(list(map(bodyparts.index, use_bodyparts)))

# Parents-first ordering rooted at 'spine_low'. Keypoint arrays are later
# permuted by `node_order`, so the edge indices (computed against
# `use_bodyparts`) are remapped into that permuted order via the inverse
# permutation `argsort(node_order)`.
node_order, parents = build_node_hierarchy(use_bodyparts, skeleton, 'spine_low')
edges = np.argsort(node_order)[np.array(kpms.get_edges(use_bodyparts, skeleton))]

### Detections only

In [71]:
# Render a QC video of raw 2D detections from every camera of one session.
vid_dir = '/n/groups/datta/Jonah/kpms_reviews_6cam_thermistor/raw_data/J01601/20230904_J01601'

CONF_THRESHOLD = 0.25  # detections below this confidence are hidden (set to NaN)
MAX_FRAMES = 15000     # upper bound on rendered frames; clamped to the data length
CROP_SIZE = 512        # crop around the centroid; panels are then downsampled 2x

# Sort so the camera layout in the output grid is deterministic
# (glob returns files in arbitrary order).
vid_paths = sorted(glob.glob(join(vid_dir, '*.avi')))

all_uvs = []
all_confs = []
for p in vid_paths:
    kp_path = os.path.splitext(p)[0] + '.keypoints.h5'
    with h5py.File(kp_path, 'r') as h5:
        # `edges` is expressed in `use_bodyparts` / `node_order` indices, so the
        # keypoint axis must be subset and permuted the same way before drawing.
        uvs = h5['uv'][()][:, use_bodyparts_ix][:, node_order]
        confs = h5['conf'][()][:, use_bodyparts_ix][:, node_order]
    uvs[confs < CONF_THRESHOLD] = np.nan
    all_uvs.append(uvs)
    all_confs.append(confs)

# Per-camera crop centre: median visible keypoint per frame, gaps filled by
# interpolation, then smoothed over time.
centroids = []
for uvs in all_uvs:
    cen = np.nanmedian(uvs, axis=1)[:, None, :]
    cen = kpms.interpolate_keypoints(cen, np.isnan(cen).all(1)[:, None]).squeeze()
    cen = gaussian_filter1d(cen, 10, axis=0)
    centroids.append(cen)

readers = [OpenCVReader(p) for p in vid_paths]

os.makedirs('qc_videos', exist_ok=True)
outpath = f'qc_videos/{vid_dir.split("/")[-1]}.detections.mp4'
# Never index past the end of the shortest keypoint track.
n_frames = min([MAX_FRAMES] + [len(uvs) for uvs in all_uvs])

with imageio.get_writer(outpath, pixelformat="yuv420p", fps=30, quality=5) as writer:
    for t in tqdm.trange(n_frames):
        overlays = []
        for i in range(len(readers)):
            im = kpms.overlay_keypoints_on_image(readers[i][t], all_uvs[i][t], edges)
            im = kpms.crop_image(im, centroids[i][t], CROP_SIZE)
            overlays.append(im[::2, ::2])
        # 2 x 3 grid; assumes six cameras.
        image = np.vstack([
            np.hstack(overlays[:3]),
            np.hstack(overlays[3:])
        ])
        image = cv2.putText(
            image, f"{t}", (10, image.shape[0]-10),
            cv2.FONT_HERSHEY_SIMPLEX, 0.9,
            (255,255,255), 2, cv2.LINE_AA
        )
        writer.append_data(image)

  0%|          | 0/15000 [00:00<?, ?it/s]

### GIMBAL only

In [10]:
# Render a QC video of GIMBAL 3D keypoints re-projected into every camera.
# Alternative session (kept for reference):
# vid_dir = '/n/groups/datta/Jonah/kpms_reviews_6cam_thermistor/raw_data/J01601/20230904_J01601'
# calib_path = '/n/groups/datta/Jonah/kpms_reviews_6cam_thermistor/raw_data/calibration/data/20230904_calibration/camera_params.h5'

vid_dir = '/n/groups/datta/Jonah/kpms_reviews_6cam_thermistor/raw_data/J01701/20230912_J01701'
calib_path = '/n/groups/datta/Jonah/kpms_reviews_6cam_thermistor/raw_data/calibration/data/20230912_calibration/camera_params.h5'

MAX_FRAMES = 25000  # upper bound on rendered frames; clamped to the data length
CROP_SIZE = 384

all_extrinsics, all_intrinsics, camera_names = mcc.load_calibration(calib_path, 'gimbal')
# Median filter of width 5 along the first (frame) axis only.
gimbal_positions = median_filter(np.load(f'{vid_dir}/gimbal.npy'), (5, 1, 1))
gimbal_uvs = [
    mcc.project_points(gimbal_positions, ext, *intr)
    for ext, intr in zip(all_extrinsics, all_intrinsics)
]
# Per-camera crop centre: mean projected keypoint, smoothed over frames.
centroids = gaussian_filter1d(np.mean(gimbal_uvs, axis=2), 10, axis=1)

readers = [OpenCVReader(f'{vid_dir}/{c}.avi') for c in camera_names]
output_dir = join(vid_dir, "qc_videos")
os.makedirs(output_dir, exist_ok=True)
output_path = join(output_dir, f'{vid_dir.split("/")[-1]}.gimbal.mp4')
# Never index past the last GIMBAL frame.
n_frames = min(MAX_FRAMES, len(gimbal_positions))

with imageio.get_writer(output_path, pixelformat="yuv420p", fps=30, quality=5) as writer:
    for t in tqdm.trange(n_frames):
        overlays = []
        for i in range(len(readers)):
            im = kpms.overlay_keypoints_on_image(readers[i][t], gimbal_uvs[i][t], edges)
            im = kpms.crop_image(im, centroids[i][t], CROP_SIZE)
            overlays.append(im)
        # 2 x 3 grid; assumes six cameras.
        image = np.vstack([
            np.hstack(overlays[:3]),
            np.hstack(overlays[3:])
        ])
        image = cv2.putText(
            image, f"{t}", (10, image.shape[0]-10),
            cv2.FONT_HERSHEY_SIMPLEX, 0.9,
            (255,255,255), 2, cv2.LINE_AA
        )
        writer.append_data(image)

  0%|          | 0/25000 [00:00<?, ?it/s]

### Stages videos

In [10]:
# Load the three processing stages (2D detections, robust triangulation,
# GIMBAL) for one session, all projected/expressed in per-camera pixel coords.
calib_path = '/n/groups/datta/Jonah/kpms_reviews_6cam_thermistor/raw_data/calibration/data/20230904_calibration/camera_params.h5'
vid_dir = '/n/groups/datta/Jonah/kpms_reviews_6cam_thermistor/raw_data/J01601/20230904_J01601'

CONF_THRESHOLD = 0.25  # detections below this confidence are hidden (set to NaN)

all_extrinsics, all_intrinsics, camera_names = mcc.load_calibration(calib_path, 'gimbal')

# Triangulated keypoints are indexed here in `bodyparts` order, so subset to
# `use_bodyparts` and permute into `node_order` to match `edges`.
triang_positions = np.load(f'{vid_dir}/robust_triangulation.npy')[:, use_bodyparts_ix][:, node_order]
triang_uvs = [
    mcc.project_points(triang_positions, ext, *intr)
    for ext, intr in zip(all_extrinsics, all_intrinsics)
]

# Median filter of width 5 along the first (frame) axis only.
gimbal_positions = median_filter(np.load(f'{vid_dir}/gimbal.npy'), (5, 1, 1))
gimbal_uvs = [
    mcc.project_points(gimbal_positions, ext, *intr)
    for ext, intr in zip(all_extrinsics, all_intrinsics)
]
# Per-camera crop centre: mean projected GIMBAL keypoint, smoothed over frames.
centroids = gaussian_filter1d(np.mean(gimbal_uvs, axis=2), 10, axis=1)

detection_uvs = []
# Iterate the names directly (not `enumerate`) so tqdm knows the total.
for c in tqdm.tqdm(camera_names):
    with h5py.File(f'{vid_dir}/{c}.keypoints.h5', 'r') as h5:
        uvs = h5['uv'][()][:, use_bodyparts_ix][:, node_order]
        confs = h5['conf'][()][:, use_bodyparts_ix][:, node_order]
    uvs[confs < CONF_THRESHOLD] = np.nan
    detection_uvs.append(uvs)

0it [00:00, ?it/s]

In [11]:
# Render one row of camera panels per processing stage for side-by-side QC.
MAX_FRAMES = 4000             # upper bound on rendered frames; clamped to the data length
CROP_SIZE = 384
OUTPUT_SIZE = (1536, 768)     # (width, height) passed to cv2.resize
STAGE_NAMES = ['detections', 'triangulation', 'gimbal']


def _put_label(img, text, org):
    """Draw small white anti-aliased `text` at pixel position `org` (x, y)."""
    return cv2.putText(img, text, org, cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                       (255, 255, 255), 1, cv2.LINE_AA)


readers = [OpenCVReader(f'{vid_dir}/{c}.avi') for c in camera_names]
os.makedirs('qc_videos', exist_ok=True)
output_path = f'qc_videos/{vid_dir.split("/")[-1]}.stages.mp4'
# Axes: (stage, camera, frame, ...); requires all three stages to have the
# same number of frames and keypoints.
all_uvs = np.stack([detection_uvs, triang_uvs, gimbal_uvs])
n_frames = min(MAX_FRAMES, all_uvs.shape[2])

with imageio.get_writer(
    output_path, pixelformat="yuv420p", fps=30, quality=5
) as writer:
    for t in tqdm.trange(n_frames):
        base_ims = [reader[t] for reader in readers]
        rows = []
        for stage_name, uvs in zip(STAGE_NAMES, all_uvs[:, :, t]):
            row = []
            for cam, base_im in enumerate(base_ims):
                # Copy so every stage draws on a clean frame.
                im = kpms.overlay_keypoints_on_image(base_im.copy(), uvs[cam], edges)
                im = kpms.crop_image(im, centroids[cam, t], CROP_SIZE)
                im = _put_label(im, stage_name, (10, 36))
                im = _put_label(im, camera_names[cam], (10, 18))
                row.append(im)
            rows.append(np.hstack(row))
        frame = np.vstack(rows)
        frame = _put_label(frame, repr(t), (10, frame.shape[0] - 12))
        frame = cv2.resize(frame, OUTPUT_SIZE)
        writer.append_data(frame)

  0%|          | 0/4000 [00:00<?, ?it/s]