# In [None]:
import os
import uuid

import cv2
import numpy as np

# Input video. Set the HOLORAY_VIDEO environment variable to use another
# file; the original hard-coded path remains the default.
VIDEO_PATH = os.environ.get(
    "HOLORAY_VIDEO", r"/Users/carolinechueh/Desktop/holoray-ui/sample.mp4"
)

# LK optical flow params (passed to cv2.calcOpticalFlowPyrLK)
lk_params = dict(
    winSize=(7, 7),
    maxLevel=3,
    # stop after 20 iterations or once the update drops below 0.03
    criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 20, 0.03),
)

# For feature seeding along the line (passed to cv2.goodFeaturesToTrack)
feature_params = dict(
    maxCorners=80,
    qualityLevel=0.01,
    minDistance=3,
    blockSize=7,
)

# id -> dict with curve_prev (Nx2), features_prev (Mx1x2), color, health
annotations = {}
# Most recent video frame; on_mouse converts it to gray to seed features.
current_frame = None
# Grayscale previous frame that optical flow tracks from.
prev_gray = None

# Stroke drawing state (updated by on_mouse, previewed by main)
drawing = False  # True while the left mouse button is held down
stroke_points = []  # raw mouse points for current stroke


def resample_curve(points, step=5):
    """Resample a freehand stroke into evenly spaced curve points.

    Samples are spread uniformly by arc length from the first to the last
    input point, using linear interpolation along the stroke's segments.

    Args:
        points: Sequence of (x, y) pairs, e.g. raw mouse positions.
        step: Target spacing between samples in pixels; the sample count is
            ``max(10, int(total_length / step))``.

    Returns:
        float32 array of shape (K, 2):
        - empty ``(0, 2)`` when fewer than two points are given;
        - the input points unchanged when the stroke's total length is
          (nearly) zero;
        - otherwise ``K >= 10`` resampled points.
    """
    if len(points) < 2:
        # Keep the Nx2 contract even when there is nothing to resample.
        return np.empty((0, 2), dtype=np.float32)

    pts = np.array(points, dtype=np.float32)
    seg_lens = np.linalg.norm(np.diff(pts, axis=0), axis=1)
    cumlen = np.concatenate([[0.0], np.cumsum(seg_lens, dtype=np.float64)])
    # Take the total from cumlen itself so the last sample lands exactly on
    # the final vertex. A separately computed sum can differ in the last bits
    # and cause a slight extrapolation past the stroke end.
    total_len = cumlen[-1]
    if total_len < 1e-3:
        return pts

    num_samples = max(10, int(total_len / step))
    t_vals = np.linspace(0.0, total_len, num_samples)

    # Segment index for each t, so that cumlen[idx] < t <= cumlen[idx + 1].
    # t == 0 yields -1 and is clamped onto the first segment.
    idx = np.clip(np.searchsorted(cumlen, t_vals) - 1, 0, len(seg_lens) - 1)
    t0 = cumlen[idx]
    span = cumlen[idx + 1] - t0
    # Zero-length segments (repeated mouse points) map to their start point.
    safe_span = np.where(span > 0, span, 1.0)
    alpha = np.where(span > 0, (t_vals - t0) / safe_span, 0.0)[:, None]

    resampled = pts[idx] * (1.0 - alpha) + pts[idx + 1] * alpha
    return resampled.astype(np.float32)


def make_band_mask(frame_shape, curve, band_width=6):
    """Build a single-channel mask that is 255 along a thick band over ``curve``.

    Args:
        frame_shape: Frame shape; only its first two entries (rows, cols)
            size the mask.
        curve: Array of (x, y) points; coordinates are truncated to int32.
        band_width: Line thickness, in pixels, used to draw the band.

    Returns:
        uint8 array of shape ``frame_shape[:2]``: 0 everywhere, 255 in the band.
    """
    polyline = curve.astype(np.int32).reshape((-1, 1, 2))
    band = np.zeros(frame_shape[:2], dtype=np.uint8)
    # Open polyline, white (255), stroked with the requested thickness.
    cv2.polylines(band, [polyline], False, 255, band_width)
    return band


def on_mouse(event, x, y, flags, param):
    """OpenCV mouse callback that turns a freehand stroke into an annotation.

    - Left-button down starts a stroke.
    - Mouse moves while drawing extend it.
    - Left-button up finishes it: the stroke is resampled, corner features
      are seeded in a band around it on ``current_frame``, and a new entry is
      added to ``annotations``.

    All events are ignored while ``current_frame`` is None.

    Args:
        event: OpenCV mouse event code (``cv2.EVENT_*``).
        x, y: Cursor position in window pixel coordinates.
        flags, param: Unused; required by ``cv2.setMouseCallback``.
    """
    global drawing, stroke_points

    if event == cv2.EVENT_LBUTTONDOWN and current_frame is not None:
        drawing = True
        stroke_points = [(x, y)]

    elif event == cv2.EVENT_MOUSEMOVE and drawing and current_frame is not None:
        stroke_points.append((x, y))

    elif event == cv2.EVENT_LBUTTONUP and drawing and current_frame is not None:
        drawing = False
        stroke_points.append((x, y))

        # 1) resample stroke to a cleaner curve
        curve = resample_curve(stroke_points, step=5)
        if curve.shape[0] < 2:
            return

        # 2) seed features in a band around the curve
        mask = make_band_mask(current_frame.shape, curve, band_width=8)
        gray = cv2.cvtColor(current_frame, cv2.COLOR_BGR2GRAY)
        pts = cv2.goodFeaturesToTrack(gray, mask=mask, **feature_params)

        if pts is None or len(pts) == 0:
            print("No features found along stroke; annotation skipped")
            return

        ann_id = str(uuid.uuid4())
        annotations[ann_id] = {
            "id": ann_id,
            "curve_prev": curve,                      # Nx2
            "features_prev": pts.astype(np.float32),  # Mx1x2
            # randint's upper bound is exclusive: 256 allows the full 0..255
            # channel range (the previous bound of 255 never produced 255).
            "color": tuple(np.random.randint(0, 256, size=3).tolist()),
            "health": 1.0,
        }
        print(f"Created annotation {ann_id} with {len(pts)} features along line")


def _track_annotation(ann, prev_gray, frame_gray):
    """Track one annotation from ``prev_gray`` into ``frame_gray``, in place.

    Runs pyramidal Lucas-Kanade flow on the annotation's features and keeps
    the successfully tracked ones. When at least 4 survive, the whole curve
    is translated by their median displacement.

    ``ann["health"]`` becomes the fraction of the features present the first
    time this annotation was tracked that are still tracked now, so it decays
    as features are lost.

    Args:
        ann: Annotation dict as stored in ``annotations``; updated in place.
        prev_gray: Grayscale frame the current feature positions refer to.
        frame_gray: Grayscale frame to track into.

    Returns:
        True if the annotation was tracked this frame and should be drawn.
        False if it had no features or optical flow returned no points.
    """
    pts_prev = ann["features_prev"]
    if pts_prev is None or len(pts_prev) == 0:
        ann["health"] = 0.0
        return False

    # Remember the seeded feature count so health accumulates losses. A
    # per-frame ratio snaps back to 1.0 as soon as a frame loses nothing.
    n_init = ann.setdefault("n_features_init", len(pts_prev))

    pts_next, status, _err = cv2.calcOpticalFlowPyrLK(
        prev_gray, frame_gray, pts_prev, None, **lk_params
    )
    if pts_next is None:
        ann["health"] = 0.0
        return False

    tracked = status.reshape(-1) == 1
    good_old = pts_prev[tracked]
    good_new = pts_next[tracked]

    if len(good_new) >= 4:
        # Median displacement is robust to a few badly tracked features;
        # the whole curve is translated rigidly by it.
        shift = np.median(good_new - good_old, axis=0).reshape(2)
        ann["curve_prev"] = ann["curve_prev"] + shift
    # With fewer than 4 survivors the curve stays where it is, but the
    # remaining features are still carried forward.

    ann["features_prev"] = good_new.reshape(-1, 1, 2)
    ann["health"] = len(good_new) / max(n_init, 1)
    return True


def _draw_annotation(img, ann_id, ann):
    """Draw an annotation's curve plus a short id/health label onto ``img``.

    The annotation color is scaled by its health, clamped to [0.2, 1.0], so
    poorly tracked lines fade but stay visible.
    """
    curve_int = ann["curve_prev"].astype(np.int32).reshape(-1, 1, 2)
    if len(curve_int) == 0:
        return
    alpha = max(0.2, min(1.0, ann["health"]))
    draw_color = tuple(int(c * alpha) for c in ann["color"])
    cv2.polylines(img, [curve_int], isClosed=False, color=draw_color, thickness=3)
    # Plain ints: some OpenCV builds reject numpy integers as text origin.
    x0, y0 = (int(v) for v in curve_int[0, 0])
    cv2.putText(
        img,
        f"{ann_id[:4]} h={ann['health']:.2f}",
        (x0 + 5, y0 - 5),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.4,
        draw_color,
        1,
        cv2.LINE_AA,
    )


def main():
    """Play VIDEO_PATH in a window and track user-drawn line annotations.

    Drawing a stroke with the left mouse button creates an annotation (see
    ``on_mouse``). Each annotation is moved frame to frame with optical flow
    and drawn with its health label. Press Esc or ``q`` to quit.
    """
    global current_frame, prev_gray

    window = "HoloRay Line Tracker"
    cap = cv2.VideoCapture(VIDEO_PATH)
    if not cap.isOpened():
        print("Error: could not open video")
        cap.release()
        return

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            fps = 30
        # cv2.waitKey(0) waits forever, so never let a high FPS round to 0 ms.
        delay = max(1, int(1000 / fps))
        print("Video FPS:", fps, "-> delay:", delay, "ms")

        cv2.namedWindow(window)
        cv2.setMouseCallback(window, on_mouse)

        ret, frame = cap.read()
        if not ret:
            print("Error: empty video")
            return
        prev_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # on_mouse seeds tracking features from current_frame, so keep it
            # free of overlays and draw everything onto a separate copy.
            # Otherwise new features latch onto the drawn lines, not the video.
            current_frame = frame
            display = frame.copy()
            frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

            for ann_id, ann in list(annotations.items()):
                if _track_annotation(ann, prev_gray, frame_gray):
                    _draw_annotation(display, ann_id, ann)

            # Show live stroke while drawing
            if drawing and len(stroke_points) > 1:
                stroke = np.array(stroke_points, dtype=np.int32).reshape(-1, 1, 2)
                cv2.polylines(
                    display, [stroke], isClosed=False, color=(0, 255, 255), thickness=2
                )

            cv2.imshow(window, display)
            key = cv2.waitKey(delay) & 0xFF
            if key in (27, ord("q")):
                break

            prev_gray = frame_gray
    finally:
        # Runs on every exit path, including the early return above.
        cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()