In [None]:
%reload_ext autoreload
%autoreload 2

from pathlib import Path
import pickle
from datetime import datetime
import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.animation import FuncAnimation
import tqdm.notebook as tqdm
import scipy.io
import h5py
import torch

In [None]:
def load_mat(filename: str):
    """Load a MATLAB .mat file into a dict of arrays.

    Classic .mat files are read with ``scipy.io.loadmat``. When scipy raises
    NotImplementedError (MATLAB v7.3 files, which are HDF5), the file is opened
    with h5py instead and every top-level entry is read in full and transposed
    (presumably to restore MATLAB's column-major dimension order).
    """
    try:
        return scipy.io.loadmat(filename)
    except NotImplementedError:
        pass
    with h5py.File(filename, "r") as handle:
        return {name: handle[name][:].T for name in handle.keys()}

def _read_pickle_stream(path: Path) -> list:
    """Return every object pickled back-to-back in `path`, in file order."""
    records: list = []
    # SECURITY: pickle.load can execute arbitrary code; only load capture files you trust.
    with open(path, "rb") as f:
        while True:
            try:
                records.append(pickle.load(f))
            except EOFError:
                break
    return records


def _mean_rotation(rotations: np.ndarray) -> np.ndarray:
    """Chordal mean of a stack of 3x3 rotations: project their sum onto SO(3) via SVD."""
    U, _, Vt = np.linalg.svd(rotations.sum(0))
    R = U @ Vt
    if np.linalg.det(R) < 0:
        # Reflection instead of rotation: flip the last singular direction.
        R = U @ np.diag([1, 1, -1]) @ Vt
    return R


def load_tracker_data(base_expdir: Path, y_offset: float = 0.02
                      ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load tracker poses and express camera/object poses in an averaged origin frame.

    Reads ``<base_expdir>/captures/<base_expdir.stem>.pkl``, a stream of pickled
    records. Each record maps "camera", "object" and "origin" to a
    ``(timestamp, pose)`` pair; poses are presumably 4x4 homogeneous world-frame
    transforms (they are composed with 4x4 matrices below).

    Args:
        base_expdir: experiment directory.
        y_offset: added to the world-frame y of the averaged origin position.

    Returns:
        (camera2origin, object2origin, T_origin2world, state_timestamps):
        camera/object poses (N, 4, 4) in the origin frame, the 4x4 origin-to-world
        transform, and (N,) per-record timestamps (mean of the three sources).

    Raises:
        AssertionError: if the pickle file does not exist.
        ValueError: if the pickle file contains no records.
    """
    pkl_path = base_expdir / "captures" / f"{base_expdir.stem}.pkl"
    assert pkl_path.exists(), pkl_path

    records = _read_pickle_stream(pkl_path)
    if not records:
        raise ValueError(f"No tracker records found in {pkl_path}")

    sources = ("camera", "object", "origin")
    ts = np.array([np.mean([r["data"][s][0] for s in sources]) for r in records])
    cam2w, obj2w, ori2w = (np.array([r["data"][s][1] for r in records]) for s in sources)

    # Average origin frame: chordal-mean rotation plus mean translation.
    R = _mean_rotation(ori2w[:, :3, :3])
    t_vec = ori2w[:, :3, 3].mean(0)
    t_vec[1] += y_offset

    T_o2w = np.eye(4)
    T_o2w[:3, :] = np.c_[R, t_vec]
    T_w2o = np.linalg.inv(T_o2w)
    # (4, 4) @ (N, 4, 4) broadcasts over the N samples.
    cam2o, obj2o = T_w2o @ cam2w, T_w2o @ obj2w
    return cam2o, obj2o, T_o2w, ts


def load_hist_data(base_expdir: Path, bank: int, glass_gate: int, offset_s: float,
                   pulse_path: str | Path = "data/macro_shape.mat",
                   ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load one detector bank's point clouds, frame timestamps and gated histograms.

    Args:
        base_expdir: experiment directory; ``out/hists`` must hold exactly one .mat file
            whose stem is a ``%Y%m%d_%H%M%S`` start time.
        bank: index into the per-bank arrays ("histos", "ptCloud", "timestamps").
        glass_gate: histogram bins with index < glass_gate are zeroed before processing.
        offset_s: seconds subtracted from the absolute frame timestamps.
        pulse_path: .mat file with the reference ``pulse_shape`` (relative to the CWD).

    Returns:
        (pt_clouds, timestamps, hists_cropped):
        pt_clouds: float array, transposed with axes (1, 2, 0) from the stored layout
            (presumably to (F, P, 3)); entries with |value| >= 100 become NaN and
            components 0 and 2 are sign-flipped.
        timestamps: absolute times in seconds, assuming raw stamps are 100-ns ticks.
        hists_cropped: (F, P, B) histograms shifted so each max bin sits at index 0,
            then gated to a window of at most 100 bins.

    Raises:
        ValueError: if ``out/hists`` does not contain exactly one .mat file.
    """
    hist_dir = base_expdir / "out" / "hists"
    mat_files = sorted(hist_dir.glob("*.mat"))
    if len(mat_files) != 1:
        raise ValueError(f"Expected exactly one .mat file in {hist_dir}, found {len(mat_files)}")
    mat_file = mat_files[0]
    mat = load_mat(str(mat_file))

    hists = mat["histos"][bank]  # (F,P,B)
    _, _, B = hists.shape

    # Reference pulse spectrum, zero-padded / truncated to the histogram length.
    pulse = scipy.io.loadmat(pulse_path)["pulse_shape"].squeeze()
    pulse_fft = np.fft.fft(pulse, n=B)

    # Suppress early (glass) reflections and the last bin. The zeroed last bin also
    # makes the clipped indices in the shift below read zeros.
    h_mask = hists.copy()
    h_mask[..., :glass_gate] = 0
    h_mask[..., -1] = 0

    # Gate start = peak of the pulse-filtered histogram minus 8 bins; gate length 100.
    # NOTE(review): spectra are multiplied without conj(), i.e. circular convolution
    # rather than cross-correlation; the peak is offset by the pulse's own peak bin.
    corr = np.abs(np.fft.ifft(np.fft.fft(h_mask, axis=-1) * pulse_fft, axis=-1))
    start = np.argmax(corr, axis=-1) - 8
    end = np.minimum(start + 100, B)

    # Shift every histogram left so its max bin lands at index 0.
    bins = np.arange(B)
    idx = np.clip(np.argmax(h_mask, -1)[..., None] + bins, 0, B - 1)
    shifted = np.take_along_axis(h_mask, idx, -1)
    # NOTE(review): start/end are bin indices of the *unshifted* histogram but are
    # applied to the shifted one; confirm this is intended.
    mask = (bins >= start[..., None]) & (bins < end[..., None])
    hists_cropped = shifted * mask

    pc = np.transpose(mat["ptCloud"][bank], (1, 2, 0)).astype(float)
    pc = np.where(np.abs(pc) < 100, pc, np.nan)  # drop outliers
    pc[..., [0, 2]] *= -1

    # Relative stamps plus the file's start time. The name is parsed as a naive
    # datetime, so .timestamp() interprets it in the machine's local time zone.
    ts = np.array(mat["timestamps"][bank], dtype=float) * 1e-7
    ts += datetime.strptime(mat_file.stem, "%Y%m%d_%H%M%S").timestamp() - offset_s
    return pc, ts, hists_cropped


def pc_to_world(pt_clouds: np.ndarray, timestamps: np.ndarray,
                camera2origin: np.ndarray, state_timestamps: np.ndarray
                ) -> np.ndarray:
    """Transform per-frame point clouds into the origin frame.

    Frame i is paired with the tracker pose whose timestamp is nearest to
    ``timestamps[i]``. Its points are mapped through
    ``camera2origin[nearest] @ T_pc2cam``, where T_pc2cam is a fixed sensor-to-camera calibration.

    Args:
        pt_clouds: (F, P, 3) points in the sensor frame; NaNs propagate.
        timestamps: (F,) frame times (presumably on the same clock as state_timestamps).
        camera2origin: (N, 4, 4) homogeneous camera poses in the origin frame.
        state_timestamps: (N,) pose times.

    Returns:
        (F, P, 3) floating-point points in the origin frame. Integer input is
        promoted rather than truncated.
    """
    R_pc2cam = np.array([[-1, 0, 0],
                         [ 0, 0, 1],
                         [ 0, 1, 0]])
    t_pc2cam = np.array([0.065, -0.040, 0.0425]) * -1  # sensor-to-camera offset (calibration)
    T_pc2cam = np.eye(4)
    T_pc2cam[:3, :3] = R_pc2cam
    T_pc2cam[:3, 3]  = t_pc2cam

    # Nearest tracker sample for every frame.
    idxs = np.abs(state_timestamps[:, None] - timestamps[None, :]).argmin(0)
    T = camera2origin[idxs] @ T_pc2cam  # (F, 4, 4)
    # x' = R x + t for every point of every frame (same as T @ [x, 1] per point).
    return np.einsum("fij,fpj->fpi", T[:, :3, :3], pt_clouds) + T[:, None, :3, 3]


In [None]:
# Experiment logs that have been analysed; only the LAST entry is used
# (the others are kept for quick switching).
expdir_candidates = [
    "logs/2025-05-18/handheld_motion_U_1steps_1p0s_static_motion_7",
    "logs/2025-05-18/handheld_motion_U_1steps_5p0s_cam_motion_1",
    "logs/2025-05-18/handheld_motion_U_1steps_5p0s_cam_motion_3",
    "logs/2025-05-19/gantry_motion_snake_20x20_2xsteps_3ysteps_128xrange_116yrange_0xinit_0yinit_1p0s_obj_motion_2",
]
expdir = Path(expdir_candidates[-1])
assert expdir.exists()

# Per-bank settings: (histogram bin width in seconds, number of leading bins
# zeroed as early reflections).
bank_settings = {1: (104e-12, 34), 0: (208e-12, 17)}
bank = 0
if bank not in bank_settings:
    raise ValueError("bank must be 0 or 1")
bin_width, glass_gate = bank_settings[bank]

camera2origin, object2origin, T_origin2world, state_timestamps = load_tracker_data(expdir)
# NOTE(review): a later cell reloads the histograms with a 1.5 s offset and overwrites
# these outputs; confirm which offset is correct.
point_clouds, timestamps, hists_cropped = load_hist_data(expdir, bank=bank, glass_gate=glass_gate, offset_s=0.0)


# filename = expdir / "captures" / f"{expdir.stem}.pkl"
# assert filename.exists(), filename

# data = []
# with open(filename, "rb") as f:
#     while True:
#         try:
#             data.append(pickle.load(f))
#         except EOFError:
#             break

# # --- extract states --------------------------------
# camera2world, object2world, origin2world, state_timestamps = [], [], [], []
# for d in data:
#     t = np.mean([d["data"]["camera"][0],
#                  d["data"]["object"][0],
#                  d["data"]["origin"][0]])
#     state_timestamps.append(t)
#     camera2world.append(d["data"]["camera"][1])
#     object2world.append(d["data"]["object"][1])
#     origin2world.append(d["data"]["origin"][1])

# camera2world = np.array(camera2world)   # (N, 4, 4)
# object2world = np.array(object2world)   # (N, 4, 4)
# origin2world = np.array(origin2world)   # (N, 4, 4)
# state_timestamps = np.array(state_timestamps)

# camera2world_orig = camera2world.copy()
# object2world_orig = object2world.copy()

# # --- average origin (base) frame via SVD --------------------------------
# M = origin2world[:, :3, :3].sum(axis=0)
# U, _, Vt = np.linalg.svd(M)
# R_mean = U @ Vt
# if np.linalg.det(R_mean) < 0:
#     R_mean = U @ np.diag([1, 1, -1]) @ Vt

# t_mean = origin2world[:, :3, 3].mean(axis=0)
# t_mean[1] += 0.02 # offset from wall

# T_origin2world = np.eye(4)
# T_origin2world[:3, :] = np.c_[R_mean, t_mean]
# T_world2origin = np.linalg.inv(T_origin2world)

# # --- poses in origin (base) frame --------------------------------
# camera2origin = T_world2origin @ camera2world   # (N, 4, 4)
# object2origin = T_world2origin @ object2world   # (N, 4, 4)

In [None]:
# The histogram stage writes exactly one .mat file per experiment.
filenames = [p for p in (expdir / "out" / "hists").glob("*.mat")]
assert len(filenames) == 1, "There should be only one .mat file in the out/hists directory."
[filename] = filenames
assert filename.exists(), filename

def load_mat(filename: str):
    """Read a .mat file into a dict, using h5py for files scipy cannot parse.

    NOTE(review): this re-defines (and shadows) the identical ``load_mat`` from the
    first code cell; consider keeping only one definition.

    scipy.io.loadmat raises NotImplementedError for MATLAB v7.3 (HDF5) files. In that
    case each top-level HDF5 entry is read in full and transposed.
    """
    loaded = None
    try:
        loaded = scipy.io.loadmat(filename)
    except NotImplementedError:
        with h5py.File(filename, "r") as h5:
            loaded = {}
            for key in h5.keys():
                loaded[key] = np.transpose(h5[key][:])
    return loaded

def peak_detection(meas, offset: int = 0, pulse_path="data/macro_shape.mat"):
    """Estimate the crop-start bin of a histogram from its pulse-filtered peak.

    The histogram is filtered with the reference ``pulse_shape`` (FFT product, i.e.
    circular convolution). The result is the argmax bin minus 8, plus ``offset``.

    Args:
        meas: histogram(s); bins along the last axis.
        offset: extra bins added to the result.
        pulse_path: .mat file holding ``pulse_shape`` (row or column vector).

    Returns:
        The crop-start bin. For a 1-D histogram this is a scalar; for batched input,
        the first entry along axis 0 of the per-histogram results is returned.
    """
    meas = np.asarray(meas)
    # MATLAB vectors load as (1, K) or (K, 1); flatten so the FFT runs along the pulse.
    pulse_shape = np.ravel(scipy.io.loadmat(pulse_path)["pulse_shape"])
    # Pad/truncate the pulse to the histogram length so the spectra line up.
    pulse_fft = np.fft.fft(pulse_shape, n=meas.shape[-1])
    meas_fft = np.fft.fft(meas, axis=-1)
    corr = np.abs(np.fft.ifft(pulse_fft * meas_fft, axis=-1))
    tof = np.atleast_1d(np.argmax(corr, axis=-1))[0]
    return tof - 8 + offset

# Bank-specific settings (same table as the configuration cell).
bank = 0
if bank == 1:
    bin_width = 104e-12
    glass_gate = 34
elif bank == 0:
    bin_width = 208e-12
    glass_gate = 17
else:
    raise ValueError("bank must be 0 or 1")

# Seconds subtracted from the histogram timestamps before matching them to tracker
# states. NOTE(review): the configuration cell uses offset_s=0.0; this value is the
# one that feeds every downstream cell; confirm which is correct.
offset = 1.5

# Crop/gate the histograms and load per-pixel points + frame timestamps. This is the
# same processing load_hist_data implements (this cell previously inlined a copy).
pc_sensor, timestamps, hists_cropped = load_hist_data(
    expdir, bank=bank, glass_gate=glass_gate, offset_s=offset
)

# Move every frame's points into the origin frame using the nearest tracker pose.
world_pt_clouds = pc_to_world(pc_sensor, timestamps, camera2origin, state_timestamps)

# Set z = 0
# world_pt_clouds[:, :, 2] = 0
pt_clouds = world_pt_clouds

# Origin-frame point clouds (black) and camera positions (red) in three orthogonal
# views; the Y-Z view is zoomed to |Z| <= 0.2.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 8))
for panel_ax, (hdim, vdim), panel_title in zip(
    (ax1, ax2, ax3), ((0, 1), (0, 2), (1, 2)), ("X-Y", "X-Z", "Y-Z")
):
    panel_ax.plot(world_pt_clouds[:, :, hdim], world_pt_clouds[:, :, vdim], "k.", markersize=1)
    panel_ax.plot(camera2origin[:, hdim, 3], camera2origin[:, vdim, 3], "r.")
    panel_ax.set_title(panel_title)

ax1.invert_xaxis()
ax2.invert_xaxis()
ax3.set_ylim(-0.2, 0.2)

plt.tight_layout()
for panel_ax in (ax1, ax2, ax3):
    panel_ax.set_aspect("equal")
plt.show()


In [None]:
# Camera trajectory (red) vs. the poses actually matched to histogram frames (green).
plt.figure(figsize=(10, 10))
idxs = np.abs(state_timestamps[:, None] - timestamps[None, :]).argmin(axis=0)
cam_xy = camera2origin[:, :2, 3]
plt.plot(cam_xy[:, 0], cam_xy[:, 1], '.r')
plt.plot(cam_xy[idxs, 0], cam_xy[idxs, 1], '.g')
plt.gca().invert_xaxis()

In [None]:
# Same three orthogonal views as the earlier cell, without the Y-Z zoom.
view_dims = {"X-Y": (0, 1), "X-Z": (0, 2), "Y-Z": (1, 2)}
fig = plt.figure(figsize=(12, 8))
view_axes = []
for panel, (view_title, (hdim, vdim)) in enumerate(view_dims.items(), start=1):
    view_ax = fig.add_subplot(1, 3, panel)
    view_ax.plot(world_pt_clouds[:, :, hdim], world_pt_clouds[:, :, vdim], "k.", markersize=1)
    view_ax.plot(camera2origin[:, hdim, 3], camera2origin[:, vdim, 3], "r.")
    view_ax.set_title(view_title)
    if hdim == 0:  # both X-* views are mirrored horizontally
        view_ax.invert_xaxis()
    view_axes.append(view_ax)
ax1, ax2, ax3 = view_axes

plt.tight_layout()
for view_ax in view_axes:
    view_ax.set_aspect("equal")
plt.show()


In [None]:
# ---------- plotting ----------
# Camera and object trajectories in the origin frame: three orthogonal views plus a
# 3-D view, each with a small RGB axis glyph drawn at the origin.
entities = [camera2origin, object2origin]
colors   = ['.r', '.g']
labels   = ['Camera', 'Object']
trajs = [e[:, :3, 3] for e in entities]
mins, maxs = np.vstack(trajs).min(0), np.vstack(trajs).max(0)

fig = plt.figure(figsize=(12, 10))
ax_xy = fig.add_subplot(2, 2, 1)
ax_xz = fig.add_subplot(2, 2, 2)
ax_yz = fig.add_subplot(2, 2, 3)
ax_3d = fig.add_subplot(2, 2, 4, projection='3d')

for traj, c, lbl in zip(trajs, colors, labels):
    x, y, z = traj.T
    ax_xy.plot(x, y, c, label=lbl)
    ax_xz.plot(x, z, c)
    ax_yz.plot(y, z, c)
    ax_3d.plot(x, y, z, c, label=lbl)

# --- coloured & labelled axes (glyph length = 10% of the trajectory extent) ---
cols = {'X': 'r', 'Y': 'g', 'Z': 'b'}
L    = 0.1 * np.linalg.norm(maxs - mins)
v    = np.eye(3) * L

# XY (horizontal X, vertical Y)
ax_xy.plot([0, v[0, 0]], [0, 0], c=cols['X'], lw=2)
ax_xy.plot([0, 0], [0, v[1, 1]], c=cols['Y'], lw=2)
ax_xy.text(v[0, 0], 0, 'X', color=cols['X'], fontsize=8, ha='left', va='bottom')
ax_xy.text(0, v[1, 1], 'Y', color=cols['Y'], fontsize=8, ha='left', va='bottom')
ax_xy.invert_xaxis()

# XZ (horizontal X, vertical Z)
ax_xz.plot([0, v[0, 0]], [0, 0], c=cols['X'], lw=2)
ax_xz.plot([0, 0], [0, v[2, 2]], c=cols['Z'], lw=2)
ax_xz.text(v[0, 0], 0, 'X', color=cols['X'], fontsize=8, ha='left', va='bottom')
ax_xz.text(0, v[2, 2], 'Z', color=cols['Z'], fontsize=8, ha='left', va='bottom')
ax_xz.invert_xaxis()

# YZ: the data is drawn as ax_yz.plot(y, z), so Y is horizontal and Z vertical.
# The Y arrow must run horizontally and the Z arrow vertically.
ax_yz.plot([0, v[1, 1]], [0, 0], c=cols['Y'], lw=2)
ax_yz.plot([0, 0], [0, v[2, 2]], c=cols['Z'], lw=2)
ax_yz.text(v[1, 1], 0, 'Y', color=cols['Y'], fontsize=8, ha='left', va='bottom')
ax_yz.text(0, v[2, 2], 'Z', color=cols['Z'], fontsize=8, ha='left', va='bottom')

# 3‑D
ax_3d.plot([0, v[0, 0]], [0, 0], [0, 0], c=cols['X'], lw=2)
ax_3d.plot([0, 0], [0, v[1, 1]], [0, 0], c=cols['Y'], lw=2)
ax_3d.plot([0, 0], [0, 0], [0, v[2, 2]], c=cols['Z'], lw=2)
ax_3d.text(v[0, 0], 0, 0, 'X', color=cols['X'], fontsize=8)
ax_3d.text(0, v[1, 1], 0, 'Y', color=cols['Y'], fontsize=8)
ax_3d.text(0, 0, v[2, 2], 'Z', color=cols['Z'], fontsize=8)

for ax in (ax_xy, ax_xz, ax_yz):
    ax.set_aspect('equal')
ax_xy.set_xlabel('X'); ax_xy.set_ylabel('Y')
ax_xz.set_xlabel('X'); ax_xz.set_ylabel('Z')
ax_yz.set_xlabel('Y'); ax_yz.set_ylabel('Z')
ax_xy.legend(loc='upper right')
ax_3d.set_box_aspect(maxs - mins)
ax_3d.set_xlabel('X'); ax_3d.set_ylabel('Y'); ax_3d.set_zlabel('Z')

plt.tight_layout()
plt.show()

In [None]:
import numpy as np
import plotly.graph_objects as go

# Interactive 3-D scene: tracker trajectories, origin-frame point clouds and an RGB
# axis triad at the origin. Relies on trajs, labels, mins, maxs (trajectory plot
# cell) and world_pt_clouds.
pts_all = world_pt_clouds.reshape(-1, 3)
L = 0.1 * np.linalg.norm(maxs - mins)
axes = np.eye(3) * L
axes_labels = ['X', 'Y', 'Z']
axes_colors = ['red', 'green', 'blue']

trajectory_traces = [
    go.Scatter3d(x=traj[:, 0], y=traj[:, 1], z=traj[:, 2], mode='lines', name=lbl)
    for traj, lbl in zip(trajs, labels)
]
point_trace = go.Scatter3d(
    x=pts_all[:, 0], y=pts_all[:, 1], z=pts_all[:, 2],
    mode='markers',
    marker=dict(size=1),
    name='points',
)
axis_traces = [
    go.Scatter3d(
        x=[0, tip[0]], y=[0, tip[1]], z=[0, tip[2]],
        mode='lines+text',
        line=dict(color=color, width=5),
        text=[None, label],
        textposition='top center',
        showlegend=False,
    )
    for tip, label, color in zip(axes, axes_labels, axes_colors)
]

fig = go.Figure(data=[*trajectory_traces, point_trace, *axis_traces])
fig.update_layout(
    scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z',
        camera=dict(
            up=dict(x=0, y=1, z=0),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=-1, y=1, z=-1),
        ),
        aspectmode='data',
    ),
    margin=dict(l=0, r=0, t=0, b=0),
)

fig.show()


In [None]:
# Per-axis position vs. sample index: camera in the left column, object in the right.
titles = ['Camera', 'Object']
axis_labels = ['X', 'Y', 'Z']
colors = ['r', 'g', 'b']
entities = [camera2origin, object2origin]
trajs = [e[:, :3, 3] for e in entities]

fig, axes = plt.subplots(3, 2, figsize=(10, 6), sharex=True)
for col, (traj, column_title) in enumerate(zip(trajs, titles)):
    axes[0, col].set_title(column_title)
    for row in range(3):
        axes[row, col].plot(traj[:, row], colors[row])
        axes[row, col].set_ylabel(axis_labels[row])

# entities = [camera_states_new, object_states_new]
# trajs = [e[:, :3, 3] for e in entities]
# titles = ['Camera New', 'Object New']
# colors = ['--r', '--g', '--b']
# for col, traj in enumerate(trajs):
#     for row, (data, label, color) in enumerate(zip(traj.T, axis_labels, colors)):
#         axes[row, col].plot(data, color)
#         axes[row, col].set_ylabel(label)

for bottom_ax in axes[2]:
    bottom_ax.set_xlabel('Time step')
plt.tight_layout()
plt.show()


In [None]:
def backproject(
    voxel_grid: np.ndarray,
    pt_clouds: np.ndarray,
    hists: np.ndarray,
    bin_width: float,
) -> np.ndarray:
    """Backproject per-pixel histograms onto a voxel grid.

    For each pixel point and each voxel at distance d from it, the voxel receives
    the sum of the pixel's histogram over bins
    ceil((d - c*bw) / (c*bw/2)) .. floor((d + c*bw) / (c*bw/2)).
    These are the bins whose round-trip delay lies within +/-2 bins of 2d/c. This
    assumes bin 0 of each histogram corresponds to the pixel point itself (as
    produced by the max-bin shift in the histogram loader); TODO confirm.

    Args:
        voxel_grid: (V, 3) voxel positions.
        pt_clouds: (P, 3) one 3-D point per pixel. Rows containing NaN/inf are
            skipped (their distances would otherwise cast to garbage bin indices).
        hists: (P, B) one histogram per pixel.
        bin_width: histogram bin width in seconds.

    Returns:
        (V, 1) accumulated counts. Voxels whose bin window starts beyond the last
        bin receive nothing rather than being clamped onto the last bin.
    """
    C = 3E8
    thresh = bin_width * C
    factor = bin_width * C / 2
    num_bins = hists.shape[1]

    # Cumulative histograms turn each window sum into two lookups.
    cum_hists = np.cumsum(hists, axis=1)

    volume = np.zeros((len(voxel_grid), 1))
    for cur_pixel, cum in zip(pt_clouds, cum_hists):
        if not np.all(np.isfinite(cur_pixel)):
            continue
        dists = np.linalg.norm(voxel_grid - cur_pixel, axis=1)

        lower_raw = np.ceil((dists - thresh) / factor).astype(int)
        upper_raw = np.floor((dists + thresh) / factor).astype(int)
        # Windows entirely outside [0, num_bins) contribute nothing.
        in_range = (lower_raw <= num_bins - 1) & (upper_raw >= 0)
        lower = np.clip(lower_raw, 0, num_bins - 1)
        upper = np.clip(upper_raw, 0, num_bins - 1)

        sums = cum[upper] - np.where(lower > 0, cum[lower - 1], 0)
        volume[:, 0] += np.where(in_range, sums, 0)

    return volume

def filter_volume(volume: np.ndarray, num_x, num_y) -> np.ndarray:
    """Apply a negative second difference along z: 2*v[k] - v[k-1] - v[k+1].

    The first and last z planes have no two-sided neighbourhood and are set to 0.
    ``num_x``/``num_y`` must equal ``volume.shape[:2]``.
    """
    center = volume[:, :, 1:-1]
    below = volume[:, :, :-2]
    above = volume[:, :, 2:]
    second_diff = 2 * center - below - above
    border = np.zeros((num_x, num_y, 1))
    return np.concatenate((border, second_diff, border), axis=-1)

def create_voxel_grid(x_range, y_range, z_range, x_res=None, y_res=None, z_res=None, num_x=None, num_y=None, num_z=None):
    """Build a regular voxel grid over [lo, hi) along each axis.

    Provide either all three resolutions (x_res, y_res, z_res) or all three counts
    (num_x, num_y, num_z); resolutions take precedence when both are given.

    Returns:
        (voxel_grid, num_x, num_y, num_z, x_res, y_res, z_res): voxel_grid is
        (num_x * num_y * num_z, 3), ordered x-major / z-fastest (``indexing='ij'``).
        The returned resolutions are the spacing actually used, (hi - lo) / num.

    Raises:
        ValueError: if neither complete set of parameters is supplied.
    """
    ranges = (x_range, y_range, z_range)
    spans = [r[1] - r[0] for r in ranges]
    if x_res is not None and y_res is not None and z_res is not None:
        # round(), not int(): e.g. 0.3 / 0.1 == 2.9999999999999996 would truncate to 2.
        nums = [max(1, int(round(span / res))) for span, res in zip(spans, (x_res, y_res, z_res))]
    elif num_x is not None and num_y is not None and num_z is not None:
        nums = [num_x, num_y, num_z]
    else:
        raise ValueError("Either resolutions (x_res, y_res, z_res) or numbers (num_x, num_y, num_z) must be provided.")

    # Report the spacing linspace really uses so index -> coordinate maps stay exact.
    res = [span / n for span, n in zip(spans, nums)]

    coords = [np.linspace(r[0], r[1], n, endpoint=False) for r, n in zip(ranges, nums)]
    xv, yv, zv = np.meshgrid(*coords, indexing='ij')
    voxel_grid = np.stack([xv.ravel(), yv.ravel(), zv.ravel()], axis=-1)
    return voxel_grid, nums[0], nums[1], nums[2], res[0], res[1], res[2]

# Reconstruction volume in the origin frame (presumably metres, like the tracker poses).
x_range = [-1, 1]
y_range = [0.0, 1.6]
z_range = [-1.5, -0.5]
num_x = 50
num_y = 50
num_z = 25
voxel_grid, num_x, num_y, num_z, x_res, y_res, z_res = create_voxel_grid(
    x_range=x_range,
    y_range=y_range,
    z_range=z_range,
    num_x=num_x,
    num_y=num_y,
    num_z=num_z,
)


def _window_mean(values: np.ndarray, step: int) -> np.ndarray:
    """Mean over consecutive non-overlapping windows of `step` rows (last window may be shorter)."""
    return np.array([np.mean(values[i:i + step], axis=0) for i in range(0, len(values), step)])


hists_to_use = hists_cropped.copy()
pt_clouds_to_use = pt_clouds.copy()
timestamps_to_use = timestamps.copy()

step_size = 20
# If step_size > 1, average every `step_size` consecutive frames. A NaN point in any
# frame of a window makes that pixel NaN for the whole window.
if step_size > 1:
    hists_to_use = _window_mean(hists_to_use, step_size)
    pt_clouds_to_use = _window_mean(pt_clouds_to_use, step_size)
    timestamps_to_use = _window_mean(timestamps_to_use, step_size)

# Tracker poses are sampled on their own clock (state_timestamps), not once per
# histogram frame. Striding them by the frame step misaligns them with the frames,
# so take the pose nearest to each (averaged) frame time instead.
pose_idxs = np.abs(state_timestamps[:, None] - timestamps_to_use[None, :]).argmin(axis=0)
camera2origin_to_use = camera2origin[pose_idxs]
object2origin_to_use = object2origin[pose_idxs]

# One backprojected volume per (averaged) frame; the points are already in the origin frame.
bp_volumes = []
for i in tqdm.tqdm(range(len(pt_clouds_to_use)), desc="Backprojecting"):
    bp_volume = backproject(
        voxel_grid=voxel_grid,
        pt_clouds=pt_clouds_to_use[i],
        hists=hists_to_use[i],
        bin_width=bin_width,
    ).reshape(num_x, num_y, num_z)
    bp_volumes.append(bp_volume)

In [None]:
def plot_axis(
    volume: np.ndarray, axis: str, *, xlim: tuple[float, float], ylim: tuple[float, float], zlim: tuple[float, float], xres: float, yres: float, zres: float, num_x: int, num_y: int, num_z: int,
    points: list[np.ndarray] | None = None, gt: list[np.ndarray] | None = None,
    idx: int | None = None, **kwargs
) -> tuple[np.ndarray, float | None]:
    """Draw one 2-D view of a voxel volume on the current pyplot axes.

    With ``idx`` set, ``volume`` is the full (num_x, num_y, num_z) grid and plane
    ``idx`` along ``axis`` is shown. With ``idx=None``, ``volume`` must already be a
    2-D image laid out for that view: "x" -> (Y, Z), "y" -> (X, Z), "z" -> (Y, X).

    Views (image rows = vertical axis):
        "x": Z horizontal, Y vertical;  "y": Z horizontal, X vertical;
        "z": X horizontal, Y vertical. Both display axes are inverted at the end.

    Args:
        volume: see above.
        axis: "x", "y" or "z".
        xlim/ylim/zlim, xres/yres/zres, num_x/num_y/num_z: grid geometry.
        points, gt: 3-D points overlaid as green / red circles. They are mapped to
            pixels with the same (num - 1) scaling used for the tick labels.
        idx: plane index for slicing, or None for a pre-projected image.
        **kwargs: forwarded to ``plt.imshow``.

    Returns:
        (img, val): the image drawn and the slice coordinate along ``axis``
        (``lim[0] + idx * res``), or None when ``idx`` is None.
    """
    assert axis in ["x", "y", "z"], f"Invalid axis {axis}"
    if axis == "x":
        img = volume[idx, :, :] if idx is not None else volume  # (Y, Z)
        h_lim, v_lim = zlim, ylim
        h_num, v_num = num_z, num_y
        h_idx, v_idx = 2, 1
        slice_lo, slice_res = xlim[0], xres
    elif axis == "y":
        img = volume[:, idx, :] if idx is not None else volume  # (X, Z)
        h_lim, v_lim = zlim, xlim
        h_num, v_num = num_z, num_x
        h_idx, v_idx = 2, 0
        slice_lo, slice_res = ylim[0], yres
    else:  # "z"
        img = volume[:, :, idx].T if idx is not None else volume  # (Y, X)
        h_lim, v_lim = xlim, ylim
        h_num, v_num = num_x, num_y
        h_idx, v_idx = 0, 1
        slice_lo, slice_res = zlim[0], zres
    val = slice_lo + idx * slice_res if idx is not None else None

    plt.imshow(img, **kwargs)

    xticks = np.round(np.linspace(0, h_num - 1, 5), 2)
    xlabels = np.round(np.linspace(h_lim[0], h_lim[1], 5), 2)
    plt.xticks(xticks, xlabels)

    yticks = np.round(np.linspace(0, v_num - 1, 5), 2)
    ylabels = np.round(np.linspace(v_lim[0], v_lim[1], 5), 2)
    plt.yticks(yticks, ylabels)

    def _to_pixel(point):
        # World coordinate -> pixel position in this view.
        px = (point[h_idx] - h_lim[0]) / (h_lim[1] - h_lim[0]) * (h_num - 1)
        py = (point[v_idx] - v_lim[0]) / (v_lim[1] - v_lim[0]) * (v_num - 1)
        return px, py

    points = [] if points is None else points
    for point in points:
        plt.plot(*_to_pixel(point), "og", markersize=10)

    gt = [] if gt is None else gt
    for gt_point in gt:
        plt.plot(*_to_pixel(gt_point), "or", markersize=10)

    plt.gca().invert_xaxis()
    plt.gca().invert_yaxis()

    return img, val

def plot_volume_slices(
    volume: np.ndarray,
    axis: str,
    title: str,
    *,
    num_cols: int = 5,
    **kwargs,
):
    """Show every plane of `volume` along `axis` in a grid of subplots.

    The volume is divided by its maximum (when that maximum is positive), and every
    panel is drawn with a colour range of [0, 1]. Remaining keyword arguments (grid
    geometry and imshow options) go to ``plot_axis``. ``cmap`` defaults to "hot" and
    ``norm`` to a gamma-2 PowerNorm.
    """
    kwargs.setdefault("cmap", "hot")
    kwargs.setdefault("norm", mcolors.PowerNorm(gamma=2))

    assert axis in ["x", "y", "z"], f"Invalid axis {axis}"
    n_slices = volume.shape["xyz".index(axis)]
    n_rows = int(np.ceil(n_slices / num_cols))
    plt.figure(figsize=(3 * num_cols, 3 * n_rows))

    peak = np.max(volume)
    if peak > 0:
        volume = volume / peak

    for slice_idx in range(n_slices):
        plt.subplot(n_rows, num_cols, slice_idx + 1)
        img, val = plot_axis(volume, axis, idx=slice_idx, **kwargs)
        plt.title(f"{axis} = {val:.2f} - Max: {np.max(img):.2f}")
        plt.clim([0, 1])

    plt.suptitle(title)
    # Colorbar of the last panel; every panel shares the [0, 1] range.
    plt.colorbar()

def plot_volume_projection(
    volume: np.ndarray,
    title: str,
    gamma: float = 2,
    project_fn = np.max,
    fig: plt.Figure | None = None,
    signal: np.ndarray | None = None,
    **kwargs,
):
    """Plot Y-Z, X-Z and X-Y projections of an (X, Y, Z) volume side by side.

    Args:
        volume: 3-D array indexed (x, y, z).
        title: figure super-title.
        gamma: exponent of the default PowerNorm.
        project_fn: reduction applied along each axis (called with ``axis=``).
        fig: if given, no new figure is created. Panels are always drawn through
            pyplot on the *current* figure, so callers should make ``fig`` current.
        signal: optional 2-D array shown as a fourth panel.
        **kwargs: grid geometry for ``plot_axis`` plus imshow options
            (``cmap`` defaults to "hot", ``norm`` to ``PowerNorm(gamma)``).
    """
    kwargs.setdefault("cmap", "hot")
    kwargs.setdefault("norm", mcolors.PowerNorm(gamma=gamma))

    x_slice = project_fn(volume, axis=0)  # (Y, Z)
    y_slice = project_fn(volume, axis=1)  # (X, Z)
    z_slice = project_fn(volume, axis=2)  # (X, Y)

    # The three volume panels share one norm object. Left unscaled, imshow autoscales
    # it from the first panel only, and the other two silently reuse those limits.
    # Scale it over all projections instead; autoscale_None keeps caller-set limits.
    norm = kwargs["norm"]
    if isinstance(norm, mcolors.Normalize):
        norm.autoscale_None(np.ma.masked_invalid(np.concatenate(
            [np.ravel(x_slice), np.ravel(y_slice), np.ravel(z_slice)]
        )))

    num = 3 if signal is None else 4
    if fig is None:
        size = 4
        fig = plt.figure(figsize=(size * num, size))

    # Y-Z
    plt.subplot(1, num, 1)
    plot_axis(x_slice, "x", **kwargs)
    plt.title("Y-Z Projection")
    plt.xlabel("Z (m)")
    plt.ylabel("Y (m)")
    plt.gca().invert_xaxis()  # undoes plot_axis' horizontal inversion

    # X-Z
    plt.subplot(1, num, 2)
    plot_axis(y_slice, "y", **kwargs)
    plt.title("X-Z Projection")
    plt.ylabel("X (m)")
    plt.xlabel("Z (m)")
    plt.gca().invert_xaxis()
    plt.gca().invert_yaxis()

    # X-Y
    plt.subplot(1, num, 3)
    plot_axis(z_slice.T, "z", **kwargs)
    plt.title("X-Y Projection")
    plt.ylabel("Y (m)")
    plt.xlabel("X (m)")

    # Signal
    if signal is not None:
        plt.subplot(1, num, 4)
        plt.imshow(signal, cmap="hot", aspect="auto")
        plt.title("Signal")

    plt.suptitle(title)

    # NOTE(review): attaches to the most recent image, i.e. the signal panel when
    # `signal` is given.
    plt.colorbar()

def get_best_point(volume: np.ndarray, x_range: tuple[float, float], y_range: tuple[float, float], z_range: tuple[float, float], x_res: float, y_res: float, z_res: float) -> np.ndarray:
    """Return the grid coordinate (lo + index * res per axis) of the volume's maximum voxel."""
    peak_index = np.array(np.unravel_index(np.argmax(volume), volume.shape))
    grid_origin = np.array([x_range[0], y_range[0], z_range[0]])
    spacing = np.array([x_res, y_res, z_res])
    return peak_index * spacing + grid_origin

# Sum the per-frame backprojections and locate the brightest voxel.
summed_volume = np.sum(bp_volumes, axis=0)
best_summed_point = get_best_point(summed_volume, x_range, y_range, z_range, x_res, y_res, z_res)

# Negative second difference along z (1-D Laplacian), then its brightest voxel.
filtered_summed_volume = filter_volume(summed_volume, num_x, num_y)
# filtered_summed_volume = filter_volume(filtered_summed_volume, num_x, num_y)
best_filtered_point = get_best_point(filtered_summed_volume, x_range, y_range, z_range, x_res, y_res, z_res)

grid_kwargs = dict(
    xlim=x_range, ylim=y_range, zlim=z_range,
    xres=x_res, yres=y_res, zres=z_res,
    num_x=num_x, num_y=num_y, num_z=num_z,
)
mean_signal = hists_cropped.mean(axis=0)

# Optional overlays: points=np.mean(object2origin_to_use[:, :3, 3], axis=0, keepdims=True),
# gt=[best_summed_point] / gt=[best_filtered_point]
for projected_volume, projection_title in (
    (summed_volume, "Summed Volume Projection"),
    (filtered_summed_volume, "Filtered Summed Volume Projection"),
):
    plot_volume_projection(
        projected_volume,
        title=projection_title,
        signal=mean_signal,
        **grid_kwargs,
    )

In [None]:
from IPython.display import HTML

# Animate the per-frame volume projections and embed them as an HTML/JS player.
fig = plt.figure(figsize=(16, 4))
pbar = tqdm.tqdm(total=len(bp_volumes), desc="Plotting", leave=False)


def update(frame):
    """Redraw the whole figure for one frame (the figure is cleared every time)."""
    pbar.update(1)
    # plot_volume_projection draws through pyplot, so make sure `fig` is current.
    plt.figure(fig.number)
    fig.clf()
    plot_volume_projection(
        bp_volumes[frame],
        f"Volume Projection - Frame {frame}",
        xlim=x_range,
        ylim=y_range,
        zlim=z_range,
        xres=x_res,
        yres=y_res,
        zres=z_res,
        num_x=num_x,
        num_y=num_y,
        num_z=num_z,
        # points=[object2origin_to_use[frame, :3, 3]],
        signal=hists_to_use[frame],
        fig=fig,
    )
    return fig.axes


# With init_func, frame 0 is not drawn an extra time, which would also over-count the
# progress bar. Blitting buys nothing when every frame clears the figure.
anim = FuncAnimation(fig, update, frames=len(bp_volumes), init_func=lambda: [],
                     interval=100, blit=False)
anim_html = HTML(anim.to_jshtml())
pbar.close()
# Close the figure so the inline backend does not also render a static last frame.
plt.close(fig)
anim_html