In [37]:
# %matplotlib widget
import argparse
from pathlib import Path
from typing import List
from dataclasses import dataclass, asdict
from datetime import datetime

import matplotlib
matplotlib.use('Agg')  # Configure backend before importing pyplot
import matplotlib.pyplot as plt
import numpy as np
from threed_utils.multiview_calibration.detection import run_checkerboard_detection, detect_chessboard, plot_chessboard_qc_data, generate_chessboard_objpoints
from threed_utils.multiview_calibration.calibration import calibrate, get_intrinsics
from threed_utils.multiview_calibration.bundle_adjustment import bundle_adjust
from threed_utils.multiview_calibration.viz import plot_residuals, plot_shared_detections
from threed_utils.multiview_calibration.geometry import triangulate
from tqdm import tqdm, trange
import hickle
from threed_utils.io import write_calibration_toml
import cv2
from pipeline_params import CalibrationOptions, DetectionOptions, DetectionRunnerOptions, ProcessingOptions
from movement.io.load_poses import from_numpy
import xarray as xr


In [None]:
def find_video_files(data_dir: Path) -> List[Path]:
    """Find video files in the data directory."""
    video_paths = [
        f for f in data_dir.iterdir() 
        if f.suffix == ".mp4" and "overlay" not in f.stem
    ]
    
    if not video_paths:
        raise ValueError(f"No video files found in {data_dir}")
    
    return video_paths

In [41]:
# Run detection
calibration_options = CalibrationOptions()
detection_runner_options = DetectionRunnerOptions()
detection_options = DetectionOptions()
n_workers = 1
folder = Path("/Users/vigji/Desktop/test_3d/Calibration/20250509/multicam_video_2025-05-09T09_56_51_cropped-v2_20250710121328")


all_calib_uvs, all_img_sizes, video_files_dict = run_checkerboard_detection(folder, 
                                                                        extension=detection_runner_options.video_extension, 
                                                                        overwrite=detection_runner_options.overwrite,
                                                                        detection_options=asdict(detection_options),
                                                                        n_workers=n_workers)

camera_names = list(video_files_dict.keys())
video_paths = list(video_files_dict.values())
print(f"Found cameras: {camera_names}")


def uvs_to_ds(all_calib_uvs, camera_names):
    all_xarrs = []
    for all_calib_pts in all_calib_uvs:
        reshaped_for_mov = np.swapaxes(all_calib_pts, 1, 2)[:, :, :, None]
        mov_xarr = from_numpy(reshaped_for_mov, individual_names=["calibrator"], keypoint_names=np.arange(reshaped_for_mov.shape[2]),
                            confidence_array=np.ones_like(reshaped_for_mov[:, 0, :, :]))
        all_xarrs.append(mov_xarr)
    new_coord_views = xr.DataArray(camera_names, dims="view")
    return xr.concat(all_xarrs, dim=new_coord_views)

uvs_ds = uvs_to_ds(all_calib_uvs, camera_names)

*.mp4
Found 5 video files in /Users/vigji/Desktop/test_3d/Calibration/20250509/multicam_video_2025-05-09T09_56_51_cropped-v2_20250710121328
Found cameras: ['central', 'mirror-bottom', 'mirror-left', 'mirror-right', 'mirror-top']


In [53]:
uvs_ds.position

array([False, False, False, ..., False, False, False], shape=(13510,))

In [75]:
valid_uvs_ds_idxs = (~np.isnan(uvs_ds.position.values[:, :, 0, 0, 0])).sum(axis=0) > 1  #.sum(0) > 1
valid_uvs_ds = uvs_ds.isel(time=valid_uvs_ds_idxs)
valid_uvs_ds.position.shape

(5, 2891, 2, 35, 1)

In [76]:
from threed_utils.anipose.movement_anipose import anipose_triangulate_ds
from threed_utils.io import load_calibration


In [107]:
from threed_utils.arena_utils import load_arena_multiview_ds
arena_json = Path("/Users/vigji/Desktop/test_3d/multicam_video_2025-05-07T10_12_11_20250528-153946.json")
arena_ds = load_arena_multiview_ds(arena_json)

In [108]:
calib_dir = folder 
calib_dir.exists()

all_calibvals = load_calibration(calib_dir)

Got calibration for the following cameras:  ['central', 'mirror-bottom', 'mirror-left', 'mirror-right', 'mirror-top']


(5, 2891, 2, 8, 1)

2891

100%|███████████████████████████████████| 8/8 [00:00<00:00, 44.98it/s]


(1, 1, 8, 3)
(1, 1, 8)
array before numpy
(1, 3, 8, 1)
(1, 8, 1)


In [None]:
calib_toml = folder / "mc_calibration_output_20250710-152443" / "calibration_from_mc.toml"
assert calib_toml.exists()
triang_config_optim = {
    "ransac": True,
    "optim": False,}
checkerboard_triang_ds = anipose_triangulate_ds(views_ds=valid_uvs_ds, calib_toml_path=calib_toml, **triang_config_optim)
# stack arena_ds over time 
arena_triang_ds = anipose_triangulate_ds(views_ds=arena_ds, calib_toml_path=calib_toml, **triang_config_optim)
arena_triang_ds = xr.concat([arena_triang_ds,]*len(valid_uvs_ds.coords["time"]), dim="time")

100%|███████████████████████| 101185/101185 [01:02<00:00, 1624.16it/s]
  dout[bp + '_ncams'] = num_cams[:, bp_num]
[0m
  dout[bp + '_score'] = scores_3d[:, bp_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num, ax_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num, ax_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num, ax_num]
[0m
  dout[bp + '_error'] = all_errors[:, bp_num]
[0m
  dout[bp + '_ncams'] = num_cams[:, bp_num]
[0m
  dout[bp + '_score'] = scores_3d[:, bp_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num, ax_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num, ax_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num, ax_num]
[0m
  dout[bp + '_error'] = all_errors[:, bp_num]
[0m
  dout[bp + '_ncams'] = num_cams[:, bp_num]
[0m
  dout[bp + '_score'] = scores_3d[:, bp_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num, ax_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num,

(2891, 1, 35, 3)
(2891, 1, 35)
array before numpy
(2891, 3, 35, 1)
(2891, 35, 1)


100%|███████████████████████| 101185/101185 [01:02<00:00, 1621.62it/s]
  dout[bp + '_ncams'] = num_cams[:, bp_num]
[0m
  dout[bp + '_score'] = scores_3d[:, bp_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num, ax_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num, ax_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num, ax_num]
[0m
  dout[bp + '_error'] = all_errors[:, bp_num]
[0m
  dout[bp + '_ncams'] = num_cams[:, bp_num]
[0m
  dout[bp + '_score'] = scores_3d[:, bp_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num, ax_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num, ax_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num, ax_num]
[0m
  dout[bp + '_error'] = all_errors[:, bp_num]
[0m
  dout[bp + '_ncams'] = num_cams[:, bp_num]
[0m
  dout[bp + '_score'] = scores_3d[:, bp_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num, ax_num]
[0m
  dout[bp + '_' + axis] = all_points_3d_adj[:, bp_num,

(2891, 1, 35, 3)
(2891, 1, 35)
array before numpy
(2891, 3, 35, 1)
(2891, 35, 1)


In [103]:
import napari

def view_movement_3d(movement_ds: "xr.Dataset", viewer: napari.Viewer | None = None):
    """
    Open napari viewer with movement data as color-coded points (2D projection).
    
    Parameters
    ----------
    movement_ds : xarray.Dataset
        Movement dataset with dimensions (time, space, keypoints, individuals)
        where space=['x', 'y', 'z'] and individuals can be squeezed
    """
    import xarray as xr
    
    if viewer is None:
        # Create napari viewer
        viewer = napari.Viewer(ndisplay=3)
    
    # Squeeze individuals dimension and get position data
    # Shape: (time, space, keypoints)
    positions = movement_ds.position.squeeze('individuals')
    n_time, n_space, n_keypoints = positions.shape
    
    # Extract x,y coordinates only (drop z)
    x_coords = positions.sel(space='x').values  # (time, keypoints)
    y_coords = positions.sel(space='y').values  # (time, keypoints)
    z_coords = positions.sel(space='z').values  # (time, keypoints)
    
    # Create time and keypoint index arrays
    time_indices = np.repeat(np.arange(n_time), n_keypoints)  # [0,0,0,...,1,1,1,...] 
    keypoint_indices = np.tile(np.arange(n_keypoints), n_time)  # [0,1,2,...,0,1,2,...]
    
    # Flatten coordinate arrays
    x_flat = x_coords.flatten()
    y_flat = y_coords.flatten()
    z_flat = z_coords.flatten()
    
    # Create mask for valid (non-NaN) points
    # valid_mask = ~(np.isnan(x_flat) | np.isnan(y_flat) | np.isnan(z_flat))
    
    # Filter to valid points only - format for napari: (time, y, x)
    points_data = np.column_stack([
        time_indices, #[valid_mask],
        y_flat, #[valid_mask], 
        x_flat, #[valid_mask],
        z_flat
        # z_flat[valid_mask]
    ])
    print(points_data.shape)
    keypoint_ids = keypoint_indices# [valid_mask]
    
    # Add detections as points layer
    viewer.add_points(
        points_data,
        features={"keypoint_id": keypoint_ids},
        face_color="keypoint_id",
        face_colormap="viridis",
        size=5,
        name="2D Keypoints"
    )
    
    print(f"Loaded {len(points_data)} valid 2D keypoints from {n_time} frames with {n_keypoints} keypoints each")
    print("Use the timeline slider to navigate through frames")
    
    return viewer

In [120]:
viewer = napari.Viewer()
view_movement_3d(checkerboard_triang_ds, viewer=viewer)
view_movement_3d(arena_triang_ds, viewer=viewer)

(101185, 4)
Loaded 101185 valid 2D keypoints from 2891 frames with 35 keypoints each
Use the timeline slider to navigate through frames
(23128, 4)
Loaded 23128 valid 2D keypoints from 2891 frames with 8 keypoints each
Use the timeline slider to navigate through frames


Viewer(camera=Camera(center=(0.0, np.float64(-16.744876622441893), np.float64(1076.2605586110178)), zoom=np.float64(3.43146008483652), angles=(0.0, 0.0, 90.0), perspective=0.0, mouse_pan=True, mouse_zoom=True), cursor=Cursor(position=(np.float64(1445.0), np.float64(-3.210827144422126), 0.0, 0.0), scaled=True, style=<CursorStyle.STANDARD: 'standard'>, size=1.0), dims=Dims(ndim=4, ndisplay=2, order=(0, 1, 2, 3), axis_labels=('0', '1', '2', '3'), rollable=(True, True, True, True), range=(RangeTuple(start=np.float64(0.0), stop=np.float64(2890.0), step=np.float64(1.0)), RangeTuple(start=np.float64(-178.21082714442213), stop=np.float64(179.8146585575794), step=np.float64(1.0)), RangeTuple(start=np.float64(-186.31876974523942), stop=np.float64(159.17363327132372), step=np.float64(1.0)), RangeTuple(start=np.float64(967.4098511716178), stop=np.float64(1138.507386067895), step=np.float64(1.0))), margin_left=(0.0, 0.0, 0.0, 0.0), margin_right=(0.0, 0.0, 0.0, 0.0), point=(np.float64(1445.0), np.fl

In [106]:
arena_ds

In [105]:
triang_ds

In [None]:

# Generate object points
calib_objpoints = generate_chessboard_objpoints(detection_options.board_shape, calibration_options.square_size)

# Run calibration
all_extrinsics, all_intrinsics, calib_poses, spanning_tree = calibrate(
    all_calib_uvs, all_img_sizes, calib_objpoints, n_samples_for_intrinsics=calibration_options.n_samples_for_intrinsics
)

# Run bundle adjustment
adj_extrinsics, adj_intrinsics, adj_calib_poses, use_frames, result = bundle_adjust(
    all_calib_uvs, all_extrinsics, all_intrinsics, calib_objpoints, calib_poses, n_frames=calibration_options.n_frames, ftol=calibration_options.ftol
)