In [None]:
def run_inference(
    model: YOLO,
    frame: np.ndarray,
    frame_number: int,
    *,
    tracker,                 # e.g. tracker = sv.ByteTrack()
    state: dict,             # holds {"ball_id": Optional[int], "centers": List[Tuple[x, y]]}
    save_csv: bool = False,
    writer: Optional = None,
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45,
    max_trail_points: Optional[int] = None,
) -> np.ndarray:
    """Run YOLO -> tracker -> annotate -> draw trajectory -> optionally log to CSV.

    The first time the tracker reports any track, the most confident tracked
    detection is "locked in" as the ball and its tracker id is stored in
    ``state["ball_id"]``. On every later frame only that track's box is drawn
    and its centre is appended to ``state["centers"]``. If no ball is locked
    yet, or the locked track is absent from this frame, all raw (untracked)
    detections are drawn instead and no centre is recorded.

    Args:
        model: Ultralytics YOLO model, called as
            ``model(frame, conf=..., iou=...)``.
        frame: Image to annotate. NOTE(review): supervision's BoxAnnotator
            normally draws into a NumPy input in place, so ``frame`` is likely
            modified and returned as the result. Copy it first if the original
            pixels are needed.
        frame_number: Index drawn in the top-left corner and written to CSV.
        tracker: Object exposing ``update_with_detections(sv.Detections)``.
        state: Mutable per-video state, initialise as
            ``state = {"ball_id": None, "centers": []}``. Missing keys are
            tolerated.
        save_csv: Whether to write one CSV row per drawn detection.
        writer: ``csv.writer``-like object with ``writerow``. It is required
            when ``save_csv`` is True and forbidden otherwise.
        conf_threshold: Confidence threshold forwarded to the model.
        iou_threshold: NMS IoU threshold forwarded to the model.
        max_trail_points: If set, keep only the most recent N centres in
            ``state["centers"]`` (trimmed in place). This bounds memory and
            drawing cost on long videos. ``None`` (default) keeps every point.

    Returns:
        The annotated image: boxes, green trajectory polyline, frame number.

    Raises:
        ValueError: If ``writer`` and ``save_csv`` disagree, or
            ``max_trail_points`` is not a positive int.

    CSV row layout:
        frame_number, tracker_id (or None), det_idx (1-based), confidence,
        center_x, center_y, x1, y1, x2, y2
    """
    # Sanity: the writer and the save_csv flag must agree.
    if save_csv and writer is None:
        raise ValueError("`writer` is required when save_csv=True")
    if not save_csv and writer is not None:
        raise ValueError("`writer` only allowed when save_csv=True")
    # 0 would make `del pts[:-0]` a no-op, so require at least one point.
    if max_trail_points is not None and max_trail_points < 1:
        raise ValueError("`max_trail_points` must be a positive int or None")

    # 1) Detect; NMS is applied by the model call via `iou`.
    yolo_res = model(frame, conf=conf_threshold, iou=iou_threshold)[0]
    dets = sv.Detections.from_ultralytics(yolo_res)

    # 2) Track
    tracked = tracker.update_with_detections(dets)
    # sv.Detections.tracker_id may be None; normalise to an empty id array so
    # the lock-in and filtering steps below never dereference None.
    track_ids = (
        np.asarray(tracked.tracker_id)
        if tracked.tracker_id is not None
        else np.empty(0, dtype=int)
    )

    # 3) Lock in ball_id on the first frame that has any track, choosing the
    #    most confident tracked detection.
    ball_id = state.get("ball_id")
    if ball_id is None and track_ids.size > 0:
        best_idx = int(np.argmax(tracked.confidence))
        ball_id = int(track_ids[best_idx])
        state["ball_id"] = ball_id

    # 4) Keep only our ball's box(es). Fall back to all raw detections when no
    #    ball is locked yet or its track is missing from this frame.
    # Index array (not a boolean mask), so it stays valid even if lengths differ.
    keep_idxs = (
        np.flatnonzero(track_ids == ball_id)
        if ball_id is not None
        else np.empty(0, dtype=int)
    )
    if keep_idxs.size > 0:
        xyxy = tracked.xyxy[keep_idxs]
        confs = tracked.confidence[keep_idxs]
        tids = [track_ids[i] for i in keep_idxs]
    else:
        xyxy = dets.xyxy
        confs = dets.confidence
        tids = [None] * len(dets.xyxy)

    # 5) Re-wrap the selection for drawing; every box is given class_id 0.
    dets_to_draw = sv.Detections(
        xyxy=np.array(xyxy),
        confidence=np.array(confs),
        class_id=np.zeros(len(xyxy), dtype=int),
    )

    # 6) Annotate boxes
    img = sv.BoxAnnotator().annotate(frame, dets_to_draw)

    # 7) Record the centre *only* when the locked ball was found this frame.
    #    Every selected box then has tid == ball_id, so the first one is used.
    if keep_idxs.size > 0:
        x1, y1, x2, y2 = dets_to_draw.xyxy[0]
        state.setdefault("centers", []).append(
            (int((x1 + x2) / 2), int((y1 + y2) / 2))
        )

    # 8) Draw the trajectory for only that ball (optionally bounded in length).
    pts = state.get("centers", [])
    if max_trail_points is not None and len(pts) > max_trail_points:
        del pts[:-max_trail_points]  # trims state["centers"] in place
    if len(pts) > 1:
        cv2.polylines(
            img,
            [np.array(pts, dtype=np.int32)],
            isClosed=False,
            color=(0, 255, 0),
            thickness=2,
        )

    # 9) Overlay frame number
    cv2.putText(
        img, str(frame_number),
        org=(10, 30),
        fontFace=cv2.FONT_HERSHEY_COMPLEX,
        fontScale=1,
        color=(255, 0, 0),
        thickness=2,
        lineType=cv2.LINE_AA,
    )

    # 10) CSV logging: one row per drawn detection (det_idx is 1-based).
    if save_csv:
        for det_idx, (xy, tid, conf) in enumerate(
            zip(dets_to_draw.xyxy, tids, dets_to_draw.confidence), start=1
        ):
            x1, y1, x2, y2 = map(float, xy)
            c_x = (x1 + x2) / 2
            c_y = (y1 + y2) / 2
            writer.writerow([
                frame_number,
                tid,
                det_idx,
                conf,
                c_x, c_y,
                x1, y1, x2, y2,
            ])

    return img