In [None]:
%reload_ext autoreload
%autoreload 2

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go

from cc_hardware.utils.file_handlers import PklReader
from cc_hardware.drivers.spads.pkl import PklSPADSensorConfig, PklSPADSensor

In [None]:
# Recording to inspect. Other captures available under ../../../logs:
#   20240922_192338, 2025-05-25/{10-38-07, 11-52-34, 11-54-54, 13-49-54, 13-58-09, 14-01-00}
pkl_path = Path("../../../logs/2025-05-25/13-59-08/data.pkl")
assert pkl_path.exists(), f"Recording not found: {pkl_path.resolve()}"

# Sensor replaying the "histograms" entries of the recording; the raw reader gives
# access to the other logged fields of each frame.
sensor = PklSPADSensor(
    PklSPADSensorConfig(
        pkl_path=pkl_path,
        merge=False,
        loop=True,
        key="histograms",
        resolution=(8, 8),
    )
)
pkl_reader = PklReader(pkl_path)
height, width = sensor.resolution
print(f"Resolution: {height} x {width}")

# Frame index inspected by the cells below.
index = 10

In [None]:
# Accumulated per-pixel histograms for the selected frame, arranged as (rows, cols, bins).
histograms = sensor.accumulate(index=index).reshape(height, width, -1)

# Display limits; set any of them to None to disable that limit.
ylim = None  # e.g. 1e4 to clip tall peaks
min_bin = 0
max_bin = min(histograms.shape[-1] - 1, 60)


def plot_histogram(i: int, j: int):
    """Draw the histogram of pixel (i, j) into its cell of a height x width subplot grid.

    Reads the notebook-level ``histograms``, ``height``, ``width``, ``ylim``,
    ``min_bin`` and ``max_bin``.
    """
    # Subplot indices are 1-based and row-major: the row stride is the number of columns.
    plt.subplot(height, width, i * width + j + 1)
    plt.plot(histograms[i, j])

    # Panels are tiny, so drop labels and ticks.
    plt.xlabel("")
    plt.ylabel("")
    plt.xticks([])
    plt.yticks([])

    if ylim is not None:
        plt.ylim(None, ylim)
    # A None side keeps the current limit; skip the call entirely when both are None.
    if min_bin is not None or max_bin is not None:
        plt.xlim(min_bin, max_bin)


# figsize is (width, height) in inches: columns drive the figure width, rows its height.
plt.figure(figsize=(width * 2, height * 2), dpi=400)
for i in range(height):
    for j in range(width):
        plot_histogram(i, j)
plt.tight_layout()
plt.show()

In [None]:
C = 3e8  # speed of light in m/s

def extract_point_cloud(
    hists: np.ndarray,
    *,
    dist_mm: np.ndarray | None = None,
    fov_x: float,
    fov_y: float,
    bin_resolution: float,
    subsample: int = 1,
    N: int = 1,
    window: int = 10,
    start_bin: int = 0,
    threshold: float = 0,
    linear_interp: bool = False
) -> np.ndarray:
    H, W, B = hists.shape
    half = window // 2

    # 1) compute weighted mean time-of-flight per pixel (in mm)
    if dist_mm is None:
        dist_mm = np.zeros((H, W), dtype=float)
        for i in range(H):
            for j in range(W):
                hist = hists[i, j]
                idx = hist.argmax()
                start = max(0, idx - half)
                end = min(B, idx + half + 1)
                bins = np.arange(start, end) + start_bin
                w = hist[start:end].astype(float)
                if w.sum() > threshold:
                    w /= w.sum()
                    t_mean = w @ (bins * bin_resolution * subsample)
                    dist_mm[i, j] = (C * t_mean / 2) * 1e3

    # 2) angular resolution per pixel
    px_x = np.deg2rad(fov_x) / W
    px_y = np.deg2rad(fov_y) / H

    # 3) build point cloud
    pts = []
    for i in range(H):
        for j in range(W):
            for u in range(N):
                for v in range(N):
                    x_sub = j + (v + 0.5) / N
                    y_sub = i + (u + 0.5) / N

                    if linear_interp:
                        x0 = int(np.floor(x_sub)); y0 = int(np.floor(y_sub))
                        x1 = min(x0 + 1, W - 1);   y1 = min(y0 + 1, H - 1)
                        dx = x_sub - x0;           dy = y_sub - y0

                        d00 = dist_mm[y0, x0]
                        d10 = dist_mm[y0, x1]
                        d01 = dist_mm[y1, x0]
                        d11 = dist_mm[y1, x1]

                        d = (
                            d00 * (1 - dx) * (1 - dy) +
                            d10 * dx       * (1 - dy) +
                            d01 * (1 - dx) * dy       +
                            d11 * dx       * dy
                        )
                    else:
                        d = dist_mm[i, j]

                    d = max(0.0, d)

                    angle_x = y_sub * px_y - np.deg2rad(fov_y) / 2 - np.pi/2
                    angle_y = x_sub * px_x - np.deg2rad(fov_x) / 2

                    x = d * np.cos(angle_x) / 1e3
                    y = d * np.sin(angle_y) / 1e3
                    z = d / 1e3
                    pts.append([x, y, z])

    return np.array(pts)

# Logged per-pixel distances for the same frame.
# NOTE(review): assumed to be millimetres since they are passed as `dist_mm` -- confirm against the driver.
distances = pkl_reader.load(index=index)["objects"].reshape(height, width)
histograms = sensor.accumulate(index=index).reshape(height, width, -1)

# Shared reconstruction parameters, so both clouds are built with identical geometry and timing.
cloud_kwargs = dict(
    fov_x=45,                # degrees
    fov_y=45,                # degrees
    bin_resolution=250e-12,  # seconds per bin, multiplied by `subsample`
    subsample=7,
    start_bin=1,
    window=10,
    N=10,
    linear_interp=True,
)
pt_cloud = extract_point_cloud(histograms, **cloud_kwargs)
pt_cloud_from_distances = extract_point_cloud(histograms, dist_mm=distances, **cloud_kwargs)

# Flatten the point clouds for Plotly
x = pt_cloud[..., 0].flatten()
y = pt_cloud[..., 1].flatten()
z = pt_cloud[..., 2].flatten()

x_from_distances = pt_cloud_from_distances[..., 0].flatten()
y_from_distances = pt_cloud_from_distances[..., 1].flatten()
z_from_distances = pt_cloud_from_distances[..., 2].flatten()

# Set to True to overlay the histogram-derived cloud (blue) on the logged-distance cloud (red).
show_histogram_cloud = False

traces = [
    go.Scatter3d(
        x=x_from_distances, y=y_from_distances, z=z_from_distances,
        mode='markers', marker=dict(color='red', size=3), name='logged distances',
    )
]
if show_histogram_cloud:
    traces.insert(0, go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers', marker=dict(color='blue', size=3), name='histogram peaks',
    ))

fig = go.Figure(data=traces)
fig.update_layout(
    scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z',
        aspectmode='data',  # or 'cube' or 'manual' with 'aspectratio'
    )
)
fig.show()