In [4]:
import warnings
warnings.filterwarnings("ignore", message="Can't initialize NVML")

import cv2
import numpy as np
from collections import deque, defaultdict
from ultralytics import YOLO
from pathlib import Path
import csv

model = YOLO("/home/debasish/Documents/YOLOv8/weights/pose-weights/yolov8x-832.pt")

input_path = "/home/debasish/Documents/YOLOv8/image052.mp4"
output_base_dir = Path("outputs/reversals")
output_base_dir.mkdir(parents=True, exist_ok=True)

VIDEO_EXTS = [".mp4", ".avi", ".mkv", ".mov"]


# -------------------------
# Create run subfolder
# -------------------------
def create_run_folder(base_dir):
    run_id = 1
    while (base_dir / f"run{run_id}").exists():
        run_id += 1
    run_dir = base_dir / f"run{run_id}"
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir

# -------------------------
# Multi-worm video processing using YOLOv8 tracking IDs
# -------------------------
def process_video(video_path, run_dir):
    video_name = Path(video_path).stem
    output_path = run_dir / f"{video_name}_reversals.avi"
    summary_path = run_dir / f"{video_name}_reversals_counted.txt"
    framewise_path = run_dir / f"{video_name}_framewise.csv"

    cap = cv2.VideoCapture(str(video_path))
    width, height = int(cap.get(3)), int(cap.get(4))
    fps = round(cap.get(cv2.CAP_PROP_FPS)) or 30
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))


    BUFFER_SIZE = int(fps)  # 1-second buffer

    # Per-worm buffers and states keyed by local worm_id (restart per video)
    head_buffers = defaultdict(lambda: deque(maxlen=BUFFER_SIZE))
    tail_buffers = defaultdict(lambda: deque(maxlen=BUFFER_SIZE))
    center_buffers = defaultdict(lambda: deque(maxlen=BUFFER_SIZE))
    directions = defaultdict(lambda: "Unknown")
    prev_directions = defaultdict(lambda: "Unknown")
    reversal_counts = defaultdict(int)
    first_flags = defaultdict(lambda: True)
    last_reversal_frame = defaultdict(lambda: -int(fps*0.5))

    colors = {}
    def get_color(worm_id):
        if worm_id not in colors:
            np.random.seed(int(worm_id))
            colors[worm_id] = tuple(np.random.randint(0, 255, size=3).tolist())
        return colors[worm_id]

    frame_count = 0
    framewise_data = []

    # Mapping YOLO track IDs to local worm IDs
    yolotrack_to_local = {}
    next_local_id = 0
    all_local_ids = set()

    # Run YOLOv8 tracker on the full video
    results_gen = model.track(str(video_path), persist=True, tracker="bytetrack.yaml")  # generator

    for results in results_gen:
        frame_count += 1
        frame = results.orig_img.copy()
        frame_dict = {"frame": frame_count}

        for result in results:
            if result.keypoints is None or result.boxes is None or result.boxes.id is None:
                continue

            yolotrack_id = int(result.boxes.id.cpu().numpy()[0])
            # Assign a local worm ID if first seen
            if yolotrack_id not in yolotrack_to_local:
                local_id = next_local_id
                yolotrack_to_local[yolotrack_id] = local_id
                next_local_id += 1
            else:
                local_id = yolotrack_to_local[yolotrack_id]

            all_local_ids.add(local_id)

            kpts = result.keypoints.xy.cpu().numpy()[0]
            head = kpts[0]
            tail = kpts[10]
            box = result.boxes.xyxy.cpu().numpy()[0]
            x1, y1, x2, y2 = box
            center = np.array([(x1+x2)/2, (y1+y2)/2])

            # Update buffers
            head_buffers[local_id].append(head)
            tail_buffers[local_id].append(tail)
            center_buffers[local_id].append(center)

            if len(center_buffers[local_id]) < BUFFER_SIZE:
                continue  # wait until buffer fills

            # Compute body vector
            avg_head = np.mean(head_buffers[local_id], axis=0)
            avg_tail = np.mean(tail_buffers[local_id], axis=0)
            body_vec = avg_head - avg_tail
            body_vec /= (np.linalg.norm(body_vec) + 1e-6)

            # Compute movement vector
            move_vec = center_buffers[local_id][-1] - center_buffers[local_id][0]
            move_norm = np.linalg.norm(move_vec)
            if move_norm < 2.0:
                move_vec = np.zeros_like(move_vec)
            else:
                move_vec /= move_norm

            # Hysteresis thresholds
            dot = np.dot(body_vec, move_vec)
            H_FORWARD = 0.3
            H_REVERSE = -0.3
            if dot > H_FORWARD:
                directions[local_id] = "Forward"
            elif dot < H_REVERSE:
                directions[local_id] = "Reverse"
            else:
                directions[local_id] = prev_directions[local_id]

            # Count reversals with cooldown
            if (not first_flags[local_id] and
                prev_directions[local_id] in ["Forward"] and
                directions[local_id] in ["Reverse"] and
                directions[local_id] != prev_directions[local_id] and
                frame_count - last_reversal_frame[local_id] > int(fps*0.5)):

                reversal_counts[local_id] += 1
                last_reversal_frame[local_id] = frame_count

            prev_directions[local_id] = directions[local_id]
            first_flags[local_id] = False

            # Draw vectors and info
            color = get_color(local_id)
            # cv2.arrowedLine(frame, tuple(avg_tail.astype(int)), tuple(avg_head.astype(int)), color, 3, tipLength=0.3)
            cv2.putText(frame,
                        f"ID {local_id} Dir:{directions[local_id]} Rev:{reversal_counts[local_id]}",
                        (int(center[0])+10, int(center[1])+10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2, cv2.LINE_AA)

            # Framewise 0=Forward, 1=Reverse
            frame_dict[f"worm_{local_id}"] = 0 if directions[local_id] == "Forward" else 1

        # Fill missing worms with blank
        for wid in all_local_ids:
            key = f"worm_{wid}"
            if key not in frame_dict:
                frame_dict[key] = ""

        framewise_data.append(frame_dict)
        out.write(frame)

    cap.release()
    out.release()
    cv2.destroyAllWindows()

    # Save summary
    with open(summary_path, "w") as f:
        for wid, count in reversal_counts.items():
            f.write(f"Worm {wid}: {count} reversals\n")

    # Save framewise CSV
    with open(framewise_path, "w", newline="") as csvfile:
        fieldnames = ["frame"] + [f"worm_{wid}" for wid in sorted(all_local_ids)]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for row in framewise_data:
            writer.writerow(row)

    print(f"[INFO] Video saved at: {output_path}")
    print(f"[INFO] Summary saved at: {summary_path}")
    print(f"[INFO] Framewise CSV saved at: {framewise_path}")


# -------------------------
# Run on single file or folder
# -------------------------
input_path = Path(input_path)

if input_path.is_file() and input_path.suffix.lower() in VIDEO_EXTS:
    run_dir = create_run_folder(output_base_dir)
    process_video(input_path, run_dir)

elif input_path.is_dir():
    videos = [f for f in input_path.iterdir() if f.suffix.lower() in VIDEO_EXTS]
    for v in videos:
        run_dir = create_run_folder(output_base_dir)
        process_video(v, run_dir)
else:
    print(f"[ERROR] No valid video found at {input_path}")


inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.

Example:
    results = model(source=..., stream=True)  # generator of Results objects
    for r in results:
        boxes = r.boxes  # Boxes object for bbox outputs
        masks = r.masks  # Masks object for segment masks outputs
        probs = r.probs  # Class probabilities for classification outputs

video 1/1 (frame 1/3034) /home/debasish/Documents/YOLOv8/image052.mp4: 640x832 8 worms, 25.3ms
video 1/1 (frame 2/3034) /home/debasish/Documents/YOLOv8/image052.mp4: 640x832 8 worms, 26.3ms
video 1/1 (frame 3/3034) /home/debasish/Documents/YOLOv8/image052.mp4: 640x832 8 worms, 25.4ms
video 1/1 (frame 4/3034) /home/debasish/Documents/YOLOv8/image052.mp4: 640x832 8 worms, 25.7ms
video 1/1 (frame 5/3034) /home/debasish/Documents/YOLOv8/image052.mp4: 640x832 8 worms, 