# Description

This notebook refines YOLO video predictions into **tracked identities** across frames. It applies identity assignment, interpolation, and visualization to produce continuous trajectories for each mouse.

# Libraries

In [None]:
from src import (
    get_video_resolution,
    get_nb_frames,
    save_metadata,
    track,
    overlay_annotations_on_video,
    get_yolo_vid_detections_in_json
)

import os
import re
from typing import Dict, Any, List
from tqdm import tqdm

# Main

In [None]:
# User configuration: fill in these three paths before running the cells below.
# Root folder of the YOLO predictions; each video is looked up under
# <yolo_predictions_dir>/<video_title>/labels.
yolo_predictions_dir = ""
# Folder that is listed for the pre-processed ".mp4" videos to track.
pre_processed_vids_path = ""
# Root folder for results; one sub-folder per video is created inside it.
output_dir = ""

# NOTE: with the empty placeholder above, os.makedirs("") raises FileNotFoundError,
# so output_dir must be set before this cell is run.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Captures the run of digits sitting immediately before a ".txt" extension at the
# very end of a label filename (e.g. "..._123.txt" -> "123"); matching ignores case.
frame_idx_regex = re.compile(
    r"(\d+)\.txt$",
    flags=re.IGNORECASE,
)

In [None]:
# Iterate over all pre-processed videos
# Main pipeline. For every ".mp4" in `pre_processed_vids_path`:
#   1. locate its YOLO label folder and read the video resolution / frame count,
#   2. convert the label files into a detections dict and save it as JSON,
#   3. run `track` to assign identities and save the tracked detections as JSON,
#   4. render an overlay video with the tracked annotations.
# A video that fails any step is reported with a message and skipped.
# The listing is sorted so the processing order is reproducible
# (os.listdir returns entries in arbitrary order).
for pp_video_name in tqdm(sorted(os.listdir(pre_processed_vids_path)), desc="Pre-Processed Videos"):

    # Skip files that are not videos
    if not pp_video_name.endswith(".mp4"):
        continue

    # Path to the current pre-processed video
    pp_video_path = os.path.join(pre_processed_vids_path, pp_video_name)

    # Title = filename without extension (used for folder names etc.).
    # splitext strips only the trailing extension; str.replace(".mp4", "") would
    # also delete any ".mp4" occurring in the middle of the name.
    video_title = os.path.splitext(pp_video_name)[0]

    # YOLO predictions are expected in Ultralytics format:
    #   <yolo_predictions_dir>/<video_title>/labels/*.txt
    video_predicted_labels_path = os.path.join(
        yolo_predictions_dir, video_title, "labels"
    )
    if not os.path.isdir(video_predicted_labels_path):
        print(f"Missing labels dir for {pp_video_name}: {video_predicted_labels_path}")
        continue

    # Output folder for this video (to save metadata + overlay video)
    output_tracked_video_path = os.path.join(output_dir, video_title)
    os.makedirs(output_tracked_video_path, exist_ok=True)

    # --- Video properties ---
    video_width, video_height = get_video_resolution(pp_video_path)
    if video_width is None or video_height is None:
        print(f"Could not get resolution for {pp_video_name}. Skipping...")
        continue

    # `not nb_frames` covers both a None result and a zero frame count
    nb_frames = get_nb_frames(pp_video_path)
    if not nb_frames:
        print(f"Could not get number of frames for {pp_video_name}. Skipping...")
        continue

    # --- Collect detections ---
    # A fresh dict per video is handed to the parser, which returns the result
    detections: Dict[str, Any] = {}
    visible_percentage = 0.85
    keypoint_names: List[str] = ["nose", "earL", "earR", "tailB"]

    # Parse YOLO .txt labels into JSON-format detections
    detections = get_yolo_vid_detections_in_json(
        video_name=pp_video_name,
        video_predicted_labels_path=video_predicted_labels_path,
        frame_idx_regex=frame_idx_regex,        # regex for extracting frame indices
        video_width=video_width,
        video_height=video_height,
        detections=detections,                  # dict to populate
        visible_percentage=visible_percentage,  # threshold for visible keypoints
        keypoint_names=keypoint_names,
    )
    if not detections:
        print(f"No detections for {pp_video_name}. Skipping...")
        continue

    print(
        f"Found {len(detections)} frames with detections for {pp_video_name}. "
        f"Saving detections under 'yolo_detections.json' at {output_tracked_video_path}"
    )
    save_metadata(
        output_dir=output_tracked_video_path,
        metadata_filename="yolo_detections.json",
        metadata=detections,
    )

    # --- Tracking parameters ---
    # NOTE(review): the keyword is spelled "nb_fames" (not "nb_frames"); it has to
    # match the parameter name declared by `track` in `src` -- confirm before renaming.
    params = dict(
        vid_name=pp_video_name,
        nb_fames=nb_frames,
        detections=detections,
        frames_skip_limit=30,
        scale_factor=0.15,
        penalty_per_missing=100,
        abs_w=video_width,
        abs_h=video_height,
        alpha=0.75,
        epsilon=1e-6,
        cost_threshold=-0.05,
        release_id_at_value=61,
        max_ids=5,
    )

    # Perform ID tracking across frames
    custom_tracked_detections = track(**params)
    if not custom_tracked_detections:
        print(f"Tracking failed for {pp_video_name}. Skipping...")
        continue

    # Save tracked detections
    save_metadata(
        output_dir=output_tracked_video_path,
        metadata_filename="custom_tracked_yolo_detections.json",
        metadata=custom_tracked_detections,
    )

    # --- Overlay tracked detections on video ---
    overlay_video_path = os.path.join(output_tracked_video_path, f"{video_title}_tracked.mp4")

    # Color mappings for bounding boxes (keyed by ID string "1".."5", matching
    # max_ids=5 above) and for keypoints (keyed by the names in keypoint_names)
    color_bbox = {
        "1": (0, 0, 255),
        "2": (0, 191, 255),
        "3": (0, 255, 0),
        "4": (255, 255, 0),
        "5": (255, 0, 191),
    }
    color_kpts = {
        "nose":  (0, 255, 255),
        "earL":  (255, 102, 102),
        "earR":  (140, 102, 255),
        "tailB": (0, 128, 255),
    }
    discard = (False, [])  # (enable_flag, [ids_to_skip])

    # Save overlay video with bounding boxes + keypoints
    overlay_annotations_on_video(
        input_video=pp_video_path,
        annotations=custom_tracked_detections,
        color_bbox=color_bbox,
        color_kpts=color_kpts,
        output_video=overlay_video_path,
        discard=discard,
    )