In [None]:
import os
import sys
import time
import cv2
import matplotlib.pyplot as plt
from ultralytics import YOLO
from ultralytics import YOLOv10
from multiprocessing import freeze_support
import torch
import numpy as np
from collections import deque
from scipy.optimize import linear_sum_assignment
from filterpy.kalman import KalmanFilter

# multiprocessing helper; per the Python docs it only has an effect when the program is
# a frozen Windows executable. NOTE(review): likely a no-op inside a notebook kernel.
freeze_support()

In [None]:
# Set seeds for reproducibility. torch.manual_seed seeds the CPU (and CUDA) generators;
# NumPy's legacy global generator is seeded too so any np.random use later in the
# notebook is reproducible as well.
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Variant suffix interpolated into the checkpoint path below
# (presumably the YOLOv10 model size, e.g. "x" -> yolov10x; TODO confirm).
n = "x"
# Load the YOLOv10 model from the local checkpoint and move it to the device chosen above
model = YOLOv10(f"chkpts/6DOF/v10{n}/yolov10{n}-detect-6dof/weights/best.pt")
model.to(device)

In [None]:
def increase_confidence_based_on_previous_frame(boxes, confs, ids, previous_ids):
    """Boost the confidence of detections whose id was also present in the previous frame.

    Args:
        boxes: not used; kept so the call signature stays stable.
        confs: per-detection confidences, indexed in step with ``ids``.
        ids: ids of the current frame's detections.
        previous_ids: container of ids seen in the previous frame.

    Returns:
        A new list where ``confs[i] + 0.1`` (capped at 1.0) is used when
        ``ids[i]`` is in ``previous_ids`` and ``confs[i]`` is used unchanged otherwise.
    """
    return [
        min(confs[pos] + 0.1, 1.0) if track_id in previous_ids else confs[pos]
        for pos, track_id in enumerate(ids)
    ]

In [None]:
def iou(bbox1, bbox2):
    """Compute the intersection over union of two ``[x1, y1, x2, y2]`` boxes.

    Args:
        bbox1, bbox2: box coordinates as any array-like (list, tuple, NumPy
            array, or a column vector); they are flattened to 4 values.

    Returns:
        IoU in [0, 1]; 0.0 when the union area is not positive.
    """
    # Coerce to flat float arrays so plain lists/tuples (where `a[2:] - a[:2]`
    # would fail) and (4, 1) column vectors are handled uniformly.
    box_a = np.asarray(bbox1, dtype=float).reshape(-1)
    box_b = np.asarray(bbox2, dtype=float).reshape(-1)
    x1, y1 = np.maximum(box_a[:2], box_b[:2])
    x2, y2 = np.minimum(box_a[2:], box_b[2:])
    intersection = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = np.prod(box_a[2:] - box_a[:2])
    area_b = np.prod(box_b[2:] - box_b[:2])
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0.0


class Track:
    """One tracked object: a 7-state Kalman filter over its box plus an appearance history.

    State layout (see ``create_kalman_filter``): indices 0-3 hold the four box
    coordinates, indices 4-6 hold velocities for coordinates 0-2 (coordinate 3
    has no velocity term in ``F``). Only indices 0-3 are observed.
    """

    def __init__(self, track_id, bbox, feature, max_age=30):
        self.track_id = track_id
        self.bbox = bbox
        # Bounded appearance history: only the 100 most recent features are kept.
        self.features = deque([feature], maxlen=100)
        self.kf = self.create_kalman_filter(bbox)
        self.time_since_update = 0  # predict() calls since the last update()
        self.hit_streak = 0  # consecutive updates without a missed frame
        self.age = 0  # total predict() calls
        self.max_age = max_age
        self.confidence = 0

    def create_kalman_filter(self, bbox):
        """Create a Kalman filter for tracking bounding boxes, initialised at ``bbox``."""
        kf = KalmanFilter(dim_x=7, dim_z=4)
        kf.F = np.array(
            [
                [1, 0, 0, 0, 1, 0, 0],
                [0, 1, 0, 0, 0, 1, 0],
                [0, 0, 1, 0, 0, 0, 1],
                [0, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 0, 1, 0],
                [0, 0, 0, 0, 0, 0, 1],
            ]
        )
        # Observe exactly state indices 0-3 so each measurement component lines
        # up with the box coordinate stored in (and read back from) x[:4].
        kf.H = np.array(
            [
                [1, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 0, 0],
                [0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0],
            ]
        )
        kf.P[
            4:, 4:
        ] *= 1000.0  # Give high uncertainty to the unobservable initial velocities
        kf.P *= 10.0
        kf.R *= 0.01
        # The filter state is a (7, 1) column; a flat (4,) box cannot broadcast
        # into x[:4], so reshape it into a column first.
        kf.x[:4] = np.asarray(bbox, dtype=float).reshape(4, 1)
        return kf

    def predict(self):
        """Advance the filter one step and return the predicted box as a flat array."""
        # A streak is broken only if the track had already missed its previous
        # update; this must be checked before the counter is incremented.
        if self.time_since_update > 0:
            self.hit_streak = 0
        self.kf.predict()
        self.age += 1
        self.time_since_update += 1
        return self.kf.x[:4].reshape(-1)

    def update(self, bbox, feature):
        """Correct the filter with a matched box and record the detection's feature."""
        self.time_since_update = 0
        self.hit_streak += 1
        self.features.append(feature)
        self.kf.update(bbox)
        self.bbox = self.kf.x[:4].reshape(-1)
        self.confidence = min(
            1.0, self.confidence + 0.1
        )  # Increase confidence with each successful update


class DeepSort:
    """IoU-matching multi-object tracker that owns a list of ``Track`` objects.

    Each ``update_tracks`` call does four things:
      1. advances every track one step;
      2. matches new detections to tracks by IoU (Hungarian assignment, gated
         by ``max_iou_distance``);
      3. updates matched tracks and starts tracks for leftover detections;
      4. drops tracks unmatched for more than ``max_age`` steps.

    NOTE(review): ``n_init`` and ``max_cosine_distance`` are stored but not
    consulted, and ``cosine_distance`` is not called by the matching code, so
    appearance features do not influence association yet.
    """

    def __init__(
        self,
        max_age=50,
        n_init=3,
        max_iou_distance=0.9,
        max_cosine_distance=0.5,
        verbose=False,
    ):
        """
        Args:
            max_age: steps a track may go unmatched before it is removed.
            n_init: stored only; not used by the matching below.
            max_iou_distance: largest ``1 - IoU`` accepted for a match.
            max_cosine_distance: stored only; not used by the matching below.
            verbose: print match diagnostics on every ``update_tracks`` call.
        """
        self.tracks = []
        self.next_id = 1
        self.max_age = max_age
        self.n_init = n_init
        self.max_iou_distance = max_iou_distance
        self.max_cosine_distance = max_cosine_distance
        self.verbose = verbose

    def cosine_distance(self, features, targets):
        """Pairwise cosine distance, shape ``(len(features), len(targets))``."""
        if len(features) == 0 or len(targets) == 0:
            return np.zeros((len(features), len(targets)))
        features = np.array(features)
        targets = np.array(targets)
        return 1.0 - np.dot(features, targets.T) / (
            np.linalg.norm(features, axis=1, keepdims=True)
            * np.linalg.norm(targets, axis=1, keepdims=True).T
        )

    def match(self, detections):
        """Assign detections to tracks by maximising total IoU.

        Args:
            detections: sequence of dicts with at least a ``"bbox"`` key.

        Returns:
            ``(matched_indices, unmatched_tracks, unmatched_detections)``.
            ``matched_indices`` is either a pair ``(track_rows, detection_cols)``
            of index arrays, or ``[]`` when there are no tracks. Assignments
            whose ``1 - IoU`` exceeds ``max_iou_distance`` are rejected, and
            both sides of a rejected pair are reported as unmatched.
        """
        if len(self.tracks) == 0:
            return [], list(range(len(detections))), []

        iou_matrix = np.zeros((len(self.tracks), len(detections)), dtype=np.float32)
        for t, track in enumerate(self.tracks):
            for d, detection in enumerate(detections):
                iou_matrix[t, d] = iou(track.bbox, detection["bbox"])

        track_rows, det_cols = linear_sum_assignment(-iou_matrix)
        # The solver pairs every track it can, even with zero overlap; gate the
        # result so boxes that barely (or do not) overlap are never associated.
        keep = (1.0 - iou_matrix[track_rows, det_cols]) <= self.max_iou_distance
        track_rows, det_cols = track_rows[keep], det_cols[keep]

        unmatched_tracks = sorted(
            set(range(len(self.tracks))) - set(track_rows.tolist())
        )
        unmatched_detections = sorted(
            set(range(len(detections))) - set(det_cols.tolist())
        )
        return (track_rows, det_cols), unmatched_tracks, unmatched_detections

    def update_tracks(self, detections, frame):
        """Update the tracks with new detections and return the surviving tracks.

        Args:
            detections: dicts with ``"bbox"`` and ``"feature"`` keys.
            frame: currently unused; kept for interface compatibility.
        """
        # Step every existing track so time_since_update counts frames without a
        # match (matched tracks reset it in Track.update); otherwise max_age
        # pruning could never fire.
        for track in self.tracks:
            track.predict()

        matched_indices, unmatched_tracks, unmatched_detections = self.match(detections)

        if self.verbose:
            print("Matched Indices: ", matched_indices)
            print("Unmatched Tracks: ", unmatched_tracks)
            print("Unmatched Detections: ", unmatched_detections)

        for t, d in zip(*matched_indices):
            self.tracks[t].update(detections[d]["bbox"], detections[d]["feature"])

        # Create new tracks for unmatched detections
        for d in unmatched_detections:
            self.tracks.append(
                Track(self.next_id, detections[d]["bbox"], detections[d]["feature"])
            )
            self.next_id += 1

        # Remove tracks that have gone unmatched for too long
        self.tracks = [t for t in self.tracks if t.time_since_update <= self.max_age]

        return self.tracks


# Module-level tracker instance. The keyword values equal the DeepSort constructor
# defaults; they are spelled out here so they are easy to tune.
deepsort = DeepSort(
    max_age=50,  # tracks are dropped once time_since_update exceeds this value
    n_init=3,  # NOTE(review): stored on the tracker but not read by DeepSort's matching
    max_iou_distance=0.9,  # intended max (1 - IoU) for a match; larger = more permissive. TODO confirm DeepSort.match applies it
    max_cosine_distance=0.5,  # NOTE(review): stored but not used; appearance features are not compared in DeepSort.match
)

In [None]:
from sympy import det


def euclidean_distance(bbox1, bbox2):
    """Return the straight-line distance between the centres of two ``[x1, y1, x2, y2]`` boxes."""
    dx = (bbox1[0] + bbox1[2]) / 2 - (bbox2[0] + bbox2[2]) / 2
    dy = (bbox1[1] + bbox1[3]) / 2 - (bbox2[1] + bbox2[3]) / 2
    return np.linalg.norm(np.array([dx, dy]))


def initialize_tracks(detections, max_tools=2):
    """Seed tool/tooltip tracks from one frame of detections.

    Up to ``max_tools`` class-0 detections (highest confidence first) become
    ``"tool"`` tracks with ids ``1..max_tools``. Every class-1 detection is then
    attached to the nearest tool track (box-centre distance) and takes that
    tool's id. When no tool track exists, tooltips are dropped.

    Args:
        detections: dicts with ``"bbox"``, ``"conf"`` and ``"cls"`` keys.
        max_tools: maximum number of tool tracks to create.

    Returns:
        List of track dicts: tools first, then tooltips, each in confidence order.
    """
    ranked = sorted(detections, key=lambda x: x["conf"], reverse=True)

    tool_tracks = []
    for detection in ranked:
        if len(tool_tracks) >= max_tools:
            break
        if detection["cls"] == 0:  # Tool
            tool_tracks.append(
                {
                    "id": len(tool_tracks) + 1,
                    "bbox": detection["bbox"],
                    "confidence": detection["conf"],
                    "type": "tool",
                }
            )

    # Tooltips are attached only after all tools exist, so a tooltip ranked
    # ahead of its tool by confidence is not discarded for lack of a partner.
    tip_tracks = []
    if tool_tracks:
        for detection in ranked:
            if detection["cls"] != 1:  # Tooltip only
                continue
            closest_tool = min(
                tool_tracks,
                key=lambda tool: euclidean_distance(tool["bbox"], detection["bbox"]),
            )
            tip_tracks.append(
                {
                    "id": closest_tool["id"],
                    "bbox": detection["bbox"],
                    "confidence": detection["conf"],
                    "type": "tooltip",
                }
            )

    return tool_tracks + tip_tracks


def match_tracks(tracks, detections, max_distance=50):
    """Pair tracks with detections by box-centre distance, one-to-one and class-aware.

    Pairing rules:
      - A ``"tool"`` track only matches class-0 detections.
      - A ``"tooltip"`` track only matches class-1 detections.
      - A track of any other type may match any class.

    Among all compatible pairs closer than ``max_distance``, the closest pairs
    are taken greedily, so each track and each detection is used at most once.

    Returns:
        List of ``(track, detection)`` tuples, ordered by track position.
    """
    type_to_cls = {"tool": 0, "tooltip": 1}

    candidates = []
    for t_idx, track in enumerate(tracks):
        wanted_cls = type_to_cls.get(track["type"])
        for d_idx, detection in enumerate(detections):
            if wanted_cls is not None and detection["cls"] != wanted_cls:
                continue
            distance = euclidean_distance(track["bbox"], detection["bbox"])
            if distance < max_distance:
                candidates.append((distance, t_idx, d_idx))

    # Stable sort on distance only: ties keep track-then-detection order.
    candidates.sort(key=lambda c: c[0])
    used_tracks, used_detections = set(), set()
    pairs = []
    for _, t_idx, d_idx in candidates:
        if t_idx in used_tracks or d_idx in used_detections:
            continue
        used_tracks.add(t_idx)
        used_detections.add(d_idx)
        pairs.append((t_idx, d_idx))

    pairs.sort()
    return [(tracks[t_idx], detections[d_idx]) for t_idx, d_idx in pairs]


def update_tracks(tracks, detections, max_distance=50, max_tools=2):
    """Propagate tracks to a new frame of detections.

    Matched tracks (see ``match_tracks``) take the detection's box and gain 0.1
    confidence (capped at 1.0); unmatched tracks are dropped. When fewer than
    ``max_tools`` tools were matched, the highest-confidence class-0 detections
    that were *not* already matched start new tool tracks. Each new track gets
    the smallest positive id not already used by a matched tool.

    Returns:
        The new list of track dicts (matched tracks are mutated in place).
    """
    matched_tracks = []
    matched_detection_ids = set()
    for track, detection in match_tracks(tracks, detections, max_distance):
        track["bbox"] = detection["bbox"]
        track["confidence"] = min(1.0, track["confidence"] + 0.1)
        matched_tracks.append(track)
        # Remember by object identity: detections are plain (unhashable) dicts.
        matched_detection_ids.add(id(detection))

    used_tool_ids = {t["id"] for t in matched_tracks if t["type"] == "tool"}
    tools_tracked = sum(1 for t in matched_tracks if t["type"] == "tool")

    # Handle missing tools if fewer than max_tools are tracked
    if tools_tracked < max_tools:
        spare_tools = sorted(
            (
                d
                for d in detections
                if d["cls"] == 0 and id(d) not in matched_detection_ids
            ),
            key=lambda x: x["conf"],
            reverse=True,
        )
        next_id = 1
        for detection in spare_tools:
            if tools_tracked >= max_tools:
                break
            while next_id in used_tool_ids:
                next_id += 1
            used_tool_ids.add(next_id)
            matched_tracks.append(
                {
                    "id": next_id,
                    "bbox": detection["bbox"],
                    "confidence": detection["conf"],
                    "type": "tool",
                }
            )
            tools_tracked += 1

    return matched_tracks


def penalize_and_filter_tracks(tracks):
    """Decay and drop tools/tooltips that lack a nearby partner with the same id.

    Support rules:
      - A tool is supported by a same-id tooltip whose centre is closer than 100.
      - A tooltip is supported by a same-id tool closer than 200.

    Unsupported tracks lose 0.2 confidence (floored at 0, written back into the
    dict) and are kept only while their confidence stays above 0.2. Tracks of
    any other type are discarded.
    """
    # own type -> (partner type, maximum centre distance to that partner)
    partner_rule = {"tool": ("tooltip", 100), "tooltip": ("tool", 200)}
    survivors = []
    for trk in tracks:
        if trk["type"] not in partner_rule:
            continue
        partner_type, max_gap = partner_rule[trk["type"]]
        supported = any(
            other["type"] == partner_type
            and other["id"] == trk["id"]
            and euclidean_distance(other["bbox"], trk["bbox"]) < max_gap
            for other in tracks
        )
        if not supported:
            trk["confidence"] = max(0, trk["confidence"] - 0.2)
        if supported or trk["confidence"] > 0.2:
            survivors.append(trk)
    return survivors

In [None]:
def visualize_tracking(
    model,
    video_path,
    n_init=10,
    max_tools=2,
    output_path="data/6DOF/tracked_output.mp4",
    max_frames=100,
    fps=20.0,
):
    """Run detection plus heuristic tool/tooltip tracking on a video and write an annotated copy.

    Tracks are rebuilt with ``initialize_tracks`` on every ``n_init``-th frame,
    and whenever no tracks survive. On other frames they are propagated with
    ``update_tracks`` and pruned with ``penalize_and_filter_tracks``.

    Args:
        model: detector called as ``model(frame, verbose=False)``; ``[0].boxes``
            must expose ``xyxy``, ``conf`` and ``cls`` tensors.
        video_path: input video readable by ``cv2.VideoCapture``.
        n_init: re-initialisation period in frames.
        max_tools: maximum number of tool tracks.
        output_path: destination MP4; its parent directory is created if missing.
        max_frames: stop after this many frames; ``None`` processes the whole video.
        fps: frame rate written to the output file.
    """
    cap = cv2.VideoCapture(video_path)
    output_dir = os.path.dirname(output_path)
    if output_dir:
        # VideoWriter cannot open a file in a directory that does not exist.
        os.makedirs(output_dir, exist_ok=True)
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)

    tracks = []
    count = 0
    try:
        while cap.isOpened():
            count += 1
            ret, frame = cap.read()
            if not ret or (max_frames is not None and count > max_frames):
                break

            # Perform inference and pull boxes / scores / class ids to NumPy.
            results = model(frame, verbose=False)
            boxes = results[0].boxes.xyxy.cpu().numpy()
            confs = results[0].boxes.conf.cpu().numpy()
            classes = results[0].boxes.cls.cpu().numpy()
            detections = [
                {"bbox": box, "conf": conf, "cls": cls}
                for box, conf, cls in zip(boxes, confs, classes)
            ]

            if count % n_init == 0 or len(tracks) == 0:
                # Reinitialize every n_init frames or if no tracks
                tracks = initialize_tracks(detections, max_tools=max_tools)
            else:
                tracks = update_tracks(tracks, detections, max_tools=max_tools)
                tracks = penalize_and_filter_tracks(tracks)

            # Draw bounding boxes and labels with tracking IDs (BGR colours).
            for track in tracks:
                x1, y1, x2, y2 = map(int, track["bbox"])
                label = f"{track['type']}-{track['id']}"
                color = (0, 255, 0) if track["type"] == "tool" else (0, 0, 255)
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                cv2.putText(
                    frame,
                    label,
                    (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.9,
                    color,
                    2,
                )

            out.write(frame)
    finally:
        # Release even if inference raises, so the capture and the MP4 are closed.
        cap.release()
        out.release()

In [None]:
def load_images(input_path):
    """Load frames from an image folder or a video file, in order.

    For a directory, every entry ending in ``.jpg`` or ``.png`` is read with
    ``cv2.imread`` in sorted filename order. Unreadable files are kept as
    ``None``, so indices stay aligned with the filenames. Any other path is
    opened with ``cv2.VideoCapture`` and read until the stream ends.
    """
    t0 = time.time()
    if os.path.isdir(input_path):
        image_names = sorted(
            name
            for name in os.listdir(input_path)
            if name.endswith((".jpg", ".png"))
        )
        images = [cv2.imread(os.path.join(input_path, name)) for name in image_names]
    else:
        images = []
        capture = cv2.VideoCapture(input_path)
        while capture.isOpened():
            ok, frame = capture.read()
            if not ok:
                break
            images.append(frame)
        capture.release()
    print("Loaded", len(images), f"images in {time.time()-t0} seconds")
    return images

In [None]:
def _relabel_box_center(bbox):
    """Return the (x, y) centre of an ``[x1, y1, x2, y2]`` box as a NumPy array."""
    return np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2])


def _relabel_carry_over(prev_tool_positions, prefix, slot, cls):
    """Rebuild ``<prefix><slot>`` from the previous frame, with its own confidence decayed by 0.1."""
    return {
        "cls": cls,
        "conf": max(prev_tool_positions[f"{prefix}{slot}conf"] - 0.1, 0.0),
        "bbox": prev_tool_positions[f"{prefix}{slot}"],
        "id": slot,
    }


def _relabel_fill_missing(detections, prev_tool_positions, prefix, cls):
    """Pad a (<= 2 element) detection list with boxes remembered from the previous frame."""
    previous = {slot: prev_tool_positions[f"{prefix}{slot}"] for slot in (1, 2)}
    if len(detections) == 1:
        # Only when both previous slots are known can we tell which one went
        # missing: it is the one farther from the detection we still have.
        if previous[1] is not None and previous[2] is not None:
            current = np.asarray(detections[0]["bbox"], dtype=float)
            dist1 = np.linalg.norm(current - np.asarray(previous[1], dtype=float))
            dist2 = np.linalg.norm(current - np.asarray(previous[2], dtype=float))
            missing_slot = 1 if dist1 > dist2 else 2
            detections.append(
                _relabel_carry_over(prev_tool_positions, prefix, missing_slot, cls)
            )
    elif not detections:
        # Nothing detected: carry over every slot that has a remembered box.
        for slot in (1, 2):
            if previous[slot] is not None:
                detections.append(
                    _relabel_carry_over(prev_tool_positions, prefix, slot, cls)
                )
    return detections


def _relabel_assign_tip_ids(tooltips, prev_tool_positions):
    """Set each tooltip's ``id`` to its nearest known tool, keeping two tooltips on distinct ids."""
    centers = {
        slot: _relabel_box_center(prev_tool_positions[f"tool{slot}"])
        for slot in (1, 2)
        if prev_tool_positions[f"tool{slot}"] is not None
    }
    if not centers:
        return  # no tool position known: keep the caller's fallback ids

    def gap(tip, slot):
        return np.linalg.norm(centers[slot] - _relabel_box_center(tip["bbox"]))

    if len(tooltips) == 2:
        first, second = tooltips
        if len(centers) == 2:
            # Pick the pairing with the smaller total centre distance.
            straight = gap(first, 1) + gap(second, 2)
            crossed = gap(first, 2) + gap(second, 1)
            slots = (1, 2) if straight <= crossed else (2, 1)
        else:
            # One known tool: the nearer tooltip takes it, the other the free id.
            (known,) = centers
            other = 3 - known
            slots = (
                (known, other)
                if gap(first, known) <= gap(second, known)
                else (other, known)
            )
        first["id"], second["id"] = slots
    else:
        for tip in tooltips:
            # Ties resolve to tool 1 (first key in ``centers``).
            tip["id"] = min(centers, key=lambda slot: gap(tip, slot))


def relabel_and_enforce_order(frame_results, prev_tool_positions):
    """Keep at most two tools and two tooltips per frame and give them stable ids 1/2.

    Steps:
      1. Split the frame's boxes into tools (class 0) and tooltips (class 1) and
         keep the two highest-confidence boxes of each.
      2. Fill gaps with boxes carried over from ``prev_tool_positions``; each
         carried box takes its own slot's confidence, decayed by 0.1.
      3. Order tools left to right and give them ids 1 and 2. The sort key is
         x1 when the most confident tool is the wider one, otherwise x2.
         Tooltips are sorted with the same key.
      4. Give each tooltip the id of its nearest tool. When there are two
         tooltips, they always receive distinct ids.

    Args:
        frame_results: one tracker result whose ``.boxes`` iterates single-box
            objects exposing ``cls``, ``conf`` and ``xyxy`` tensors.
        prev_tool_positions: state dict updated in place with this frame's
            boxes. Keys ``tool1``, ``tool2``, ``tooltip1`` and ``tooltip2``
            hold an xyxy box or None; each has a matching ``...conf`` key.

    Returns:
        List of ``{"cls", "conf", "bbox", "id"}`` dicts, tools first and then
        tooltips. The list may hold fewer than two of each when nothing is
        available to fill from.
    """
    tools, tooltips = [], []
    for det in frame_results.boxes:
        # Each ``det`` holds a single box, so take element 0 of every tensor.
        cls = int(det.cls.cpu().numpy()[0])
        entry = {
            "cls": cls,
            "conf": det.conf.cpu().numpy()[0],
            "bbox": det.xyxy.cpu().numpy()[0],
            "id": None,
        }
        if cls == 0:
            tools.append(entry)
        elif cls == 1:
            tooltips.append(entry)

    # Keep the two most confident boxes of each class.
    tools = sorted(tools, key=lambda d: d["conf"], reverse=True)[:2]
    tooltips = sorted(tooltips, key=lambda d: d["conf"], reverse=True)[:2]

    tools = _relabel_fill_missing(tools, prev_tool_positions, "tool", 0)
    tooltips = _relabel_fill_missing(tooltips, prev_tool_positions, "tooltip", 1)

    # Left-to-right ordering key (see step 3 in the docstring).
    sort_coord = 0
    if len(tools) >= 2:
        first_width = tools[0]["bbox"][2] - tools[0]["bbox"][0]
        second_width = tools[1]["bbox"][2] - tools[1]["bbox"][0]
        sort_coord = 0 if first_width > second_width else 2
    tools_sorted = sorted(tools, key=lambda d: d["bbox"][sort_coord])
    tooltips = sorted(tooltips, key=lambda d: d["bbox"][sort_coord])

    # Leftmost tool -> id 1, rightmost -> id 2; remember them for the next frame.
    for slot, tool in enumerate(tools_sorted, start=1):
        tool["id"] = slot
        prev_tool_positions[f"tool{slot}"] = tool["bbox"]
        prev_tool_positions[f"tool{slot}conf"] = tool["conf"]

    # Fallback ids (left-to-right order), used when no tool position is known.
    for slot, tip in enumerate(tooltips, start=1):
        tip["id"] = slot
    _relabel_assign_tip_ids(tooltips, prev_tool_positions)

    # Store tooltips under their assigned id so slot N always means "tip of tool N".
    for tip in tooltips:
        prev_tool_positions[f"tooltip{tip['id']}"] = tip["bbox"]
        prev_tool_positions[f"tooltip{tip['id']}conf"] = tip["conf"]

    return tools_sorted + tooltips


def process_input(model, input_path, output_path, images):
    """Track, relabel and annotate every frame, then encode them into ``output.mp4``.

    Args:
        model: model exposing ``track(source, save=..., verbose=..., stream=True)``.
        input_path: source passed to ``model.track`` (image folder or video).
        output_path: directory that receives ``output.mp4`` (created if missing).
        images: frames in the same order ``model.track`` yields results (e.g.
            from ``load_images(input_path)``); they are annotated in place.

    Returns:
        List of ``[cls, x1, y1, x2, y2]`` integer boxes for every drawn
        detection, across all frames in order.
    """
    # NOTE: with stream=True this only creates a lazy generator; inference runs
    # as the loop below consumes it, so this timing does not include it.
    track_start = time.time()
    results = model.track(input_path, save=False, verbose=False, stream=True)
    print(f"Tracking complete in {time.time()-track_start} seconds")

    # Initialize tracking correction
    prev_tool_positions = {"tool1": None, "tool2": None, "tooltip1": None, "tooltip2": None, "tool1conf": 0, "tool2conf": 0, "tooltip1conf": 0, "tooltip2conf": 0}
    bounding_boxes = []
    # OpenCV needs BGR tuples; colour-name strings are not accepted by cv2.rectangle.
    tool_color = (255, 0, 0)  # blue
    tooltip_color = (0, 165, 255)  # orange

    process_time = time.time()
    os.makedirs(output_path, exist_ok=True)
    output_video_path = os.path.join(output_path, "output.mp4")
    writer = None
    write_seconds = 0.0
    try:
        for idx, frame_results in enumerate(results):
            frame = images[idx]
            processed_results = relabel_and_enforce_order(
                frame_results, prev_tool_positions
            )
            for det in processed_results:
                try:
                    x1, y1, x2, y2 = map(int, det["bbox"])
                except (TypeError, ValueError):
                    # Missing (None) or non-finite box: nothing to record or draw.
                    continue
                cls = det["cls"]
                bounding_boxes.append([cls, x1, y1, x2, y2])
                # make conf percent
                label = f"{'Tool' if det['cls'] == 0 else 'Tooltip'} #{det['id']}, {det['conf']*100:.2f}%"
                color = tool_color if cls == 0 else tooltip_color
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 3)
                cv2.putText(
                    frame, label, (x1, y1 - 15), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 3
                )
            if idx % 50 == 0:
                print(f"Processed frame {idx+1}/{len(images)}")

            # Encode directly (no temporary image files), so nothing else in
            # output_path is read or deleted.
            write_start = time.time()
            if writer is None:
                height, width = frame.shape[:2]
                writer = cv2.VideoWriter(
                    output_video_path, cv2.VideoWriter_fourcc(*"mp4v"), 20.0, (width, height)
                )
            writer.write(frame)
            write_seconds += time.time() - write_start
    finally:
        if writer is not None:
            writer.release()

    if writer is None:
        print(f"No frames were produced for {input_path}; no video written")
    else:
        print(f"Video saved in {write_seconds} seconds in {output_video_path}")

    print(f"Processing complete: {time.time()-process_time} seconds")
    return bounding_boxes

In [None]:
import cv2
import numpy as np
import os
import time

# Confidence cut-off constant. NOTE(review): not referenced in the visible part of this
# notebook — confirm where (and against which score) it is applied.
CONFIDENCE_THRESHOLD = 0.75


def calculate_black_pixel_ratio(img, bbox):
    """Score how dark the bottom-left and how light the bottom-right of a box crop are.

    The crop is converted to grayscale, and pixels with intensity <= 50 count
    as black. Let ``bl`` and ``br`` be the fractions of black pixels in the
    bottom-left and bottom-right quadrants. The score is ``bl + (1 - br)``, in
    [0, 2]. An empty quadrant counts as having no black pixels, so an empty
    crop scores 1.0.

    Args:
        img: BGR image array of shape (H, W, 3).
        bbox: ``[x1, y1, x2, y2]`` box; coordinates are clamped to the image.
    """
    height, width = img.shape[:2]
    x1, y1, x2, y2 = map(int, bbox)
    # Clamp to the image: negative indices would otherwise wrap around in slicing.
    x1, x2 = max(0, x1), min(width, x2)
    y1, y2 = max(0, y1), min(height, y2)
    if x2 <= x1 or y2 <= y1:
        # Empty crop (cvtColor would raise): both quadrants have no black pixels.
        return 1.0
    tool_img = img[y1:y2, x1:x2]

    # Convert to grayscale and threshold to find black pixels
    gray = cv2.cvtColor(tool_img, cv2.COLOR_BGR2GRAY)
    _, black_mask = cv2.threshold(gray, 50, 255, cv2.THRESH_BINARY_INV)

    # Split into quadrants (a 1-pixel-wide crop leaves bottom_left empty)
    h, w = black_mask.shape
    bottom_left = black_mask[h // 2 :, : w // 2]
    bottom_right = black_mask[h // 2 :, w // 2 :]

    bl_ratio = np.sum(bottom_left == 255) / bottom_left.size if bottom_left.size else 0.0
    br_ratio = np.sum(bottom_right == 255) / bottom_right.size if bottom_right.size else 0.0

    # Return combined ratio
    return bl_ratio + (1 - br_ratio)


def determine_tool_order(img, tool_bboxes):
    """Return ``(left_index, right_index)`` into the first two entries of ``tool_bboxes``.

    The box whose ``calculate_black_pixel_ratio`` score is strictly higher is
    taken as the left tool; on a tie the input order (0 left, 1 right) is kept.
    """
    first_score = calculate_black_pixel_ratio(img, tool_bboxes[0])
    second_score = calculate_black_pixel_ratio(img, tool_bboxes[1])
    if second_score > first_score:
        return 1, 0
    return 0, 1


def calculate_iou(bbox1, bbox2):
    """Calculate the Intersection over Union (IoU) of two ``[x1, y1, x2, y2]`` boxes.

    Returns:
        IoU as a float in [0, 1]; 0.0 when the union area is not positive
        (e.g. two degenerate boxes), instead of raising ZeroDivisionError.
    """
    x1 = max(bbox1[0], bbox2[0])
    y1 = max(bbox1[1], bbox2[1])
    x2 = min(bbox1[2], bbox2[2])
    y2 = min(bbox1[3], bbox2[3])

    inter_area = max(0, x2 - x1) * max(0, y2 - y1)

    bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])

    union_area = bbox1_area + bbox2_area - inter_area
    if union_area <= 0:
        return 0.0
    # Local named `overlap_ratio` rather than `iou` to avoid shadowing the global iou().
    overlap_ratio = inter_area / float(union_area)
    return overlap_ratio


def calculate_overlap(bbox1, bbox2):
    """Return the intersection area of two ``[x1, y1, x2, y2]`` boxes (0 when they do not overlap)."""
    overlap_w = min(bbox1[2], bbox2[2]) - max(bbox1[0], bbox2[0])
    overlap_h = min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1])
    return max(0, overlap_w) * max(0, overlap_h)


def calculate_distance(tool, tip):
    """Centre distance between a tool box and a tooltip box, discounted by their overlap.

    Computes ``d * (1 - overlap / tip_area)``. Here ``d`` is the Euclidean
    distance between the two box centres and ``overlap`` is
    ``calculate_overlap(tip, tool)``. A tip lying entirely inside the tool
    therefore scores 0, and smaller values mean a better tool/tip pairing.
    Boxes are ``[x1, y1, x2, y2]``.

    If the tip box has zero or negative area, the overlap ratio is undefined.
    The plain centre distance is returned instead of dividing by zero.
    """
    tip_center = np.array(
        [
            (tip[0] + tip[2]) / 2,
            (tip[1] + tip[3]) / 2,
        ]
    )
    tool_center = np.array(
        [
            (tool[0] + tool[2]) / 2,
            (tool[1] + tool[3]) / 2,
        ]
    )
    distance = np.linalg.norm(tool_center - tip_center)
    tip_area = (tip[2] - tip[0]) * (tip[3] - tip[1])
    if tip_area <= 0:
        return distance
    overlap = calculate_overlap(tip, tool)
    return distance * (1 - overlap / tip_area)

def _simpler_top_two(detections):
    """Keep at most the two most confident detections (input order when <= 2)."""
    if len(detections) <= 2:
        return detections
    return sorted(detections, key=lambda det: det["conf"], reverse=True)[:2]


def _simpler_append_previous(detections, prev, name, cls, det_id, min_conf=None):
    """Append last frame's box ``name`` (e.g. "tool2") to ``detections`` as ID ``det_id``.

    Nothing is appended when that box has never been seen (bbox still None)
    or, if ``min_conf`` is given, when its stored confidence is not above it.
    """
    bbox = prev[name]
    conf = prev[f"{name}conf"]
    if bbox is None or (min_conf is not None and not conf > min_conf):
        return
    detections.append({"cls": cls, "conf": conf, "bbox": bbox, "id": det_id})


def _simpler_remember(prev, detections, prefix):
    """Store bbox/conf of every detection with ID 1 or 2 under ``<prefix><id>`` keys."""
    for det in detections:
        det_id = det.get("id")
        if det_id in (1, 2):
            prev[f"{prefix}{det_id}"] = det["bbox"]
            prev[f"{prefix}{det_id}conf"] = det["conf"]


def process_input_simpler(model, input_path, output_path, images):
    """Track tools and tooltips with stable left/right IDs and export a video.

    For every result yielded by ``model.track(input_path, stream=True)``:

    * Tools (class 0): two detections are ordered with ``determine_tool_order``
      (ID 1 = left, ID 2 = right); a single detection gets the ID of the
      previous tool it has the larger IoU with, and the other tool's last box
      is carried over; with no detections, last frame's boxes are reused if
      their confidence was above ``CONFIDENCE_THRESHOLD``.
    * Tooltips (class 1): two detections get the IDs of the tools they match
      best (cheaper pairing under ``calculate_distance``), or left/right by x2
      when both tools are not available; a single detection gets the ID of the
      tool box it overlaps most; with none, confident previous tooltips are
      reused. If there are no tools at all, tooltips are numbered by x2.
    * More than two detections of a class: only the two most confident are used.

    Detections above ``CONFIDENCE_THRESHOLD`` are drawn (ID 1 blue, ID 2 red,
    BGR), each frame is written as ``frame_XXXX.jpg`` in ``output_path``, and
    those frames are combined into ``new_tracking.mp4`` (20 fps) and deleted.

    Args:
        model: Ultralytics-style model exposing ``.track``.
        input_path: source passed to ``model.track``.
        output_path: directory for frames and video (created if needed).
        images: frames indexed by result position; assumed to be in the same
            order as ``model.track`` yields results — TODO confirm.

    Returns:
        list: ``[cls, x1, y1, x2, y2]`` for every drawn box, in frame order.
    """
    prev = {
        "tool1": None,
        "tool2": None,
        "tooltip1": None,
        "tooltip2": None,
        "tool1conf": 0,
        "tool2conf": 0,
        "tooltip1conf": 0,
        "tooltip2conf": 0,
    }
    bounding_boxes = []
    written_frames = []  # JPEG paths written by this call (only these are turned into the video)
    frame_size = None  # (width, height) for the VideoWriter

    track_start = time.time()
    # stream=True returns a lazy generator: inference runs inside the loop
    # below, so this only times generator creation.
    results = model.track(input_path, save=False, verbose=False, stream=True)
    print(f"Tracking complete in {time.time()-track_start} seconds")

    process_time = time.time()
    os.makedirs(output_path, exist_ok=True)
    for idx, frame_results in enumerate(results):
        try:
            # Draw on a copy so the caller's images stay unannotated.
            frame = images[idx].copy()
            tools = []
            tooltips = []
            for det in frame_results.boxes:
                cls = int(det.cls.item())
                entry = {
                    "cls": cls,
                    "conf": det.conf.item(),
                    "bbox": det.xyxy[0].cpu().numpy(),
                }
                if cls == 0:
                    tools.append(entry)
                elif cls == 1:
                    tooltips.append(entry)

            # ---- Tools ----------------------------------------------------
            tools = _simpler_top_two(tools)
            if len(tools) == 2:
                left_idx, right_idx = determine_tool_order(
                    frame, [tools[0]["bbox"], tools[1]["bbox"]]
                )
                tools[left_idx]["id"] = 1
                tools[right_idx]["id"] = 2
            elif len(tools) == 1:
                current_tool = tools[0]
                iou_tool1 = (
                    calculate_iou(prev["tool1"], current_tool["bbox"])
                    if prev["tool1"] is not None
                    else 0
                )
                iou_tool2 = (
                    calculate_iou(prev["tool2"], current_tool["bbox"])
                    if prev["tool2"] is not None
                    else 0
                )
                if iou_tool1 > iou_tool2:
                    current_tool["id"] = 1
                    _simpler_append_previous(tools, prev, "tool2", 0, 2)
                else:
                    current_tool["id"] = 2
                    _simpler_append_previous(tools, prev, "tool1", 0, 1)
            else:
                _simpler_append_previous(tools, prev, "tool1", 0, 1, CONFIDENCE_THRESHOLD)
                _simpler_append_previous(tools, prev, "tool2", 0, 2, CONFIDENCE_THRESHOLD)
            # The single-tooltip rule below uses this frame's tool positions.
            _simpler_remember(prev, tools, "tool")

            # ---- Tooltips -------------------------------------------------
            tooltips = _simpler_top_two(tooltips)
            tool_boxes = {tool["id"]: tool["bbox"] for tool in tools}
            if len(tooltips) == 2:
                first, second = tooltips
                if 1 in tool_boxes and 2 in tool_boxes:
                    # Give each tooltip the ID of the tool it belongs to by
                    # choosing the cheaper of the two possible pairings.
                    straight = calculate_distance(
                        tool_boxes[1], first["bbox"]
                    ) + calculate_distance(tool_boxes[2], second["bbox"])
                    swapped = calculate_distance(
                        tool_boxes[1], second["bbox"]
                    ) + calculate_distance(tool_boxes[2], first["bbox"])
                    if straight <= swapped:
                        first["id"], second["id"] = 1, 2
                    else:
                        first["id"], second["id"] = 2, 1
                else:
                    # Not enough tools to match against: left tooltip is 1, right is 2.
                    if first["bbox"][2] <= second["bbox"][2]:
                        first["id"], second["id"] = 1, 2
                    else:
                        first["id"], second["id"] = 2, 1
            elif len(tooltips) == 1:
                current_tooltip = tooltips[0]
                overlap_tool1 = (
                    calculate_overlap(prev["tool1"], current_tooltip["bbox"])
                    if prev["tool1"] is not None
                    else 0
                )
                overlap_tool2 = (
                    calculate_overlap(prev["tool2"], current_tooltip["bbox"])
                    if prev["tool2"] is not None
                    else 0
                )
                if overlap_tool1 > overlap_tool2:
                    current_tooltip["id"] = 1
                    _simpler_append_previous(tooltips, prev, "tooltip2", 1, 2)
                else:
                    current_tooltip["id"] = 2
                    _simpler_append_previous(tooltips, prev, "tooltip1", 1, 1)
            else:
                _simpler_append_previous(
                    tooltips, prev, "tooltip1", 1, 1, CONFIDENCE_THRESHOLD
                )
                _simpler_append_previous(
                    tooltips, prev, "tooltip2", 1, 2, CONFIDENCE_THRESHOLD
                )

            # No tools at all: number tooltips left to right by their x2 coordinate.
            if not tools and tooltips:
                ordered = sorted(tooltips, key=lambda tip: tip["bbox"][2])
                ordered[0]["id"] = 1
                if len(ordered) > 1:
                    ordered[1]["id"] = 2
            _simpler_remember(prev, tooltips, "tooltip")

            # ---- Annotate -------------------------------------------------
            for det in tools + tooltips:
                if det["conf"] > CONFIDENCE_THRESHOLD:
                    x1, y1, x2, y2 = map(int, det["bbox"])
                    cls = det["cls"]
                    bounding_boxes.append([cls, x1, y1, x2, y2])
                    label = f"{'Tool' if cls == 0 else 'Tooltip'} #{det['id']}, {det['conf']*100:.2f}%"
                    color = (255, 0, 0) if det["id"] == 1 else (0, 0, 255)
                    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 3)
                    cv2.putText(
                        frame, label, (x1, y1 - 15), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 3
                    )

            frame_path = os.path.join(output_path, f"frame_{idx:04d}.jpg")
            cv2.imwrite(frame_path, frame)
            written_frames.append(frame_path)
            frame_size = (frame.shape[1], frame.shape[0])
            if idx % 50 == 0:
                print(f"Processed frame {idx+1}/{len(images)}")
        except Exception as exc:
            # Best effort: report the failing frame and keep going instead of
            # silently dropping it (the old bare `except: pass`).
            print(f"Skipping frame {idx}: {exc!r}")

    if not written_frames:
        print("No frames were written; skipping video creation.")
        return bounding_boxes

    # Combine the frames written above into a video, then delete them.
    video_time = time.time()
    output_video_path = os.path.join(output_path, "new_tracking.mp4")
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_video_path, fourcc, 20.0, frame_size)
    try:
        for frame_path in written_frames:
            out.write(cv2.imread(frame_path))
            os.remove(frame_path)
    finally:
        out.release()

    print(f"Video saved in {time.time()-video_time} seconds in {output_video_path}")
    print(f"Processing complete: {time.time()-process_time} seconds")
    return bounding_boxes

To address the issues with tool and tooltip tracking, `process_input_simpler` (defined in the cell above) adjusts the logic for cases where only one tool or tooltip is detected, aiming for proper ID assignment and consistent frame-by-frame tracking.
### Summary of Key Updates:

1. **Tool Detection Logic**: 
   - If two tools are detected, they are assigned IDs based on the black-pixel ratio inside their bounding boxes (see `determine_tool_order`).
   - If only one tool is detected, its ID is determined by comparing the IoU with the previous tools.
   - If no tools are detected, we carry over tools from the previous frame only if their confidence was above the threshold.

2. **Tooltip Detection Logic**: 
   - Tooltips are matched to the closest tool, with the same ID as the tool they are closest to.
   - If no tools are detected, but tooltips are, they are assigned IDs based on their horizontal position.

3. **Consistency Check**: 
   - There is no separate pass that detects or corrects duplicate IDs; uniqueness relies on each branch above assigning IDs 1 and 2 as a pair.

4. **Updating Previous Positions**: 
   - The previous positions and confidence levels are updated at the end of processing each frame, ensuring that the tracking remains consistent across frames.

These changes are intended to handle the single-detection and missing-detection edge cases and to make tracking of tools and tooltips more robust across frames.

In [None]:
import cv2
import numpy as np
import os
import time

# Minimum detection confidence: the tracking functions in this notebook only
# keep/draw detections whose confidence is strictly greater than this value.
CONFIDENCE_THRESHOLD = 0.75


def get_bottom_left_coord(bbox):
    """Return the ``(x1, y2)`` corner of an ``(x1, y1, x2, y2)`` box.

    With image coordinates (y grows downwards) this is the bottom-left corner.
    """
    left_x = bbox[0]
    bottom_y = bbox[3]
    return left_x, bottom_y


def get_center_coord(bbox):
    """Return the midpoint ``(cx, cy)`` of an ``(x1, y1, x2, y2)`` box."""
    left, top, right, bottom = bbox
    return ((left + right) / 2, (top + bottom) / 2)


def assign_left_right(detections, prev_left_bbox=None, prev_right_bbox=None):
    """Label detections in place with ``"id"`` 1 (left) or 2 (right).

    Args:
        detections: list of dicts holding an ``(x1, y1, x2, y2)`` ``"bbox"``;
            a ``"conf"`` value is only used when more than two are given.
        prev_left_bbox: box that had ID 1 in the previous frame, or None.
        prev_right_bbox: box that had ID 2 in the previous frame, or None.

    Rules:
        * two detections: the one whose bottom-left x is smaller gets ID 1
          (on a tie the first detection gets ID 2);
        * more than two: the same rule is applied to the two with the highest
          ``"conf"``; the remaining detections are left without an ``"id"``;
        * one detection: ID of the previous box whose centre is nearer (ID 2
          on a tie); ID 1 when there is no previous box at all;
        * none: nothing to do.

    Returns:
        None; the dicts in ``detections`` are modified in place.
    """
    if len(detections) >= 2:
        if len(detections) > 2:
            # Duplicate/spurious detections: without this, none of them would
            # get an ID and callers could not label anything this frame.
            pair = sorted(
                detections, key=lambda det: det.get("conf", 0.0), reverse=True
            )[:2]
        else:
            pair = detections
        first, second = pair
        if get_bottom_left_coord(first["bbox"])[0] < get_bottom_left_coord(second["bbox"])[0]:
            first["id"], second["id"] = 1, 2  # Left, Right
        else:
            first["id"], second["id"] = 2, 1  # Right, Left
    elif len(detections) == 1:
        only = detections[0]
        if prev_left_bbox is None and prev_right_bbox is None:
            only["id"] = 1  # Assume it's left if no previous data
            return
        current_center = np.array(get_center_coord(only["bbox"]))

        def _distance_to(prev_bbox):
            # Missing history counts as infinitely far away.
            if prev_bbox is None:
                return float("inf")
            return np.linalg.norm(current_center - np.array(get_center_coord(prev_bbox)))

        only["id"] = 1 if _distance_to(prev_left_bbox) < _distance_to(prev_right_bbox) else 2


def process_input_simplest(model, input_path, output_path, images):
    """Track tools/tooltips with a positional left/right rule and export a video.

    Only detections with confidence above ``CONFIDENCE_THRESHOLD`` are used.
    ``assign_left_right`` labels tools and tooltips independently (ID 1 = left,
    ID 2 = right), using the previous frame's boxes when only one is visible.
    Detections that end up without an ID are ignored. Each annotated frame is
    written as ``frame_XXXX.jpg`` in ``output_path``; the frames written by this
    call are then combined into ``easy_tracking.mp4`` (20 fps) and deleted.

    Args:
        model: Ultralytics-style model exposing ``.track``.
        input_path: source passed to ``model.track``.
        output_path: directory for frames and video (created if needed).
        images: frames indexed by result position; assumed to be in the same
            order as ``model.track`` yields results — TODO confirm.

    Returns:
        list: ``[cls, x1, y1, x2, y2]`` for every drawn box, in frame order.
    """
    bounding_boxes = []
    prev_tool_left = prev_tool_right = prev_tooltip_left = prev_tooltip_right = None
    written_frames = []  # JPEG paths written by this call
    frame_size = None  # (width, height) for the VideoWriter

    track_start = time.time()
    # stream=True returns a lazy generator: inference runs inside the loop
    # below, so this only times generator creation.
    results = model.track(input_path, save=False, verbose=False, stream=True)
    print(f"Tracking complete in {time.time()-track_start} seconds")

    process_time = time.time()
    os.makedirs(output_path, exist_ok=True)
    for idx, frame_results in enumerate(results):
        # Draw on a copy so the caller's images stay unannotated.
        frame = images[idx].copy()
        tools = []
        tooltips = []

        for det in frame_results.boxes:
            cls = int(det.cls.item())
            conf = det.conf.item()
            if conf <= CONFIDENCE_THRESHOLD:
                continue
            entry = {"cls": cls, "conf": conf, "bbox": det.xyxy[0].cpu().numpy()}
            if cls == 0:
                tools.append(entry)
            elif cls == 1:
                tooltips.append(entry)

        assign_left_right(tools, prev_tool_left, prev_tool_right)
        assign_left_right(tooltips, prev_tooltip_left, prev_tooltip_right)
        # assign_left_right may leave detections without an ID (e.g. when more
        # than two were found); they cannot be labelled, so drop them.
        tools = [det for det in tools if "id" in det]
        tooltips = [det for det in tooltips if "id" in det]

        # Remember this frame's boxes for the next frame.
        for tool in tools:
            if tool["id"] == 1:
                prev_tool_left = tool["bbox"]
            elif tool["id"] == 2:
                prev_tool_right = tool["bbox"]
        for tip in tooltips:
            if tip["id"] == 1:
                prev_tooltip_left = tip["bbox"]
            elif tip["id"] == 2:
                prev_tooltip_right = tip["bbox"]

        for det in tools + tooltips:
            x1, y1, x2, y2 = map(int, det["bbox"])
            cls = det["cls"]
            bounding_boxes.append([cls, x1, y1, x2, y2])
            label = f"{'Tool' if cls == 0 else 'Tooltip'} #{det['id']}, {det['conf']*100:.2f}%"
            color = (255, 0, 0) if det["id"] == 1 else (0, 0, 255)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 3)
            cv2.putText(
                frame, label, (x1, y1 - 15), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 3
            )

        frame_path = os.path.join(output_path, f"frame_{idx:04d}.jpg")
        cv2.imwrite(frame_path, frame)
        written_frames.append(frame_path)
        frame_size = (frame.shape[1], frame.shape[0])
        if idx % 50 == 0:
            print(f"Processed frame {idx+1}/{len(images)}")

    if not written_frames:
        print("No frames were written; skipping video creation.")
        return bounding_boxes

    # Combine only the frames written above (other images in output_path are
    # left alone), then delete them.
    video_time = time.time()
    output_video_path = os.path.join(output_path, "easy_tracking.mp4")
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_video_path, fourcc, 20.0, frame_size)
    try:
        for frame_path in written_frames:
            out.write(cv2.imread(frame_path))
            os.remove(frame_path)
    finally:
        out.release()

    print(f"Video saved in {time.time()-video_time} seconds in {output_video_path}")
    print(f"Processing complete: {time.time()-process_time} seconds")
    return bounding_boxes

In [None]:
from ultralytics import YOLOv10
import time
# Other model sizes that can be added to the sweep: "n", "s", "m", "b", "l",
for n in ["x"]:
    start_time = time.time()
    # Use the device chosen at the top of the notebook instead of hard-coding
    # "cuda", so this cell also runs on CPU-only machines.
    model = YOLOv10(f"chkpts/6DOF/v10{n}/yolov10{n}-detect-6dof/weights/best.pt").to(
        device
    )
    print(f"Loaded model in {time.time()-start_time} seconds")
    # process_input(
    #     model,
    #     "data/6DOF/images/val",
    #     f"chkpts/6DOF/v10{n}/tracking",
    #     load_images("data/6DOF/images/val"),
    # )
    # process_input_simpler(
    #     model,
    #     "data/6DOF/images/val",
    #     f"chkpts/6DOF/v10{n}/focused_tracking",
    #     load_images("data/6DOF/images/val"),
    # )
    process_input_simplest(
        model,
        "data/6DOF/images/val",
        f"chkpts/6DOF/v10{n}/easy_tracking",
        load_images("data/6DOF/images/val"),
    )
    print("Done with", n, f"in {time.time()-start_time} seconds")

In [None]:
from ultralytics import YOLO

for n in ["n", "s", "m", "l", "x"]:
    start_time = time.time()
    # Use the device chosen at the top of the notebook instead of hard-coding
    # "cuda", so this cell also runs on CPU-only machines.
    model = YOLO(f"chkpts/6DOF/v8{n}/yolov8{n}-detect-6dof/weights/best.pt").to(
        device
    )
    print(f"Loaded model in {time.time()-start_time} seconds")
    # process_input(
    #     model,
    #     "data/6DOF/images/val",
    #     f"chkpts/6DOF/v8{n}/tracking",
    #     load_images("data/6DOF/images/val"),
    # )
    # process_input_simpler(
    #     model,
    #     "data/6DOF/images/val",
    #     f"chkpts/6DOF/v8{n}/focused_tracking",
    #     load_images("data/6DOF/images/val"),
    # )
    process_input_simplest(
        model,
        "data/6DOF/images/val",
        f"chkpts/6DOF/v8{n}/easy_tracking",
        load_images("data/6DOF/images/val"),
    )
    print("Done with", n, f"in {time.time()-start_time} seconds")

In [None]:
import cv2


def draw_bounding_boxes_with_opacity(
    image_path, bounding_boxes, output_path, alpha_decay=0.1
):
    """
    Draw a fading trail of boxes on an image: the most recent entry is fully
    opaque and each older entry is ``alpha_decay`` more transparent (minimum 0.01).

    :param image_path: Path to the final image.
    :param bounding_boxes: List of bounding boxes in the format [class_id, x1, y1, x2, y2],
        oldest first. Entries with class_id 0 (tools) are not drawn but still
        count towards the age of older entries. The list is not modified.
    :param output_path: Path to save the output image.
    :param alpha_decay: Amount by which opacity decreases for older boxes.
    :raises FileNotFoundError: if ``image_path`` cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    color = (0, 255, 0)  # green
    n_boxes = len(bounding_boxes)
    # Oldest first, so newer (more opaque) boxes end up drawn on top.
    for idx, (class_id, x1, y1, x2, y2) in enumerate(bounding_boxes):
        if class_id == 0:
            continue
        age = n_boxes - 1 - idx  # 0 for the most recent entry
        alpha = max(1 - age * alpha_decay, 0.01)  # Ensure a minimum opacity level

        # Blend each box from a fresh overlay; a shared overlay would re-blend
        # (and re-fade) every previously drawn box on each iteration.
        overlay = image.copy()
        cv2.rectangle(overlay, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
        cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)

    cv2.imwrite(output_path, image)
    print(f"Image saved with bounding boxes at {output_path}")

# Input frame and output file for the opacity-trail visualisation.
image_path, output_path = (
    "data/6DOF/images/test/test5_319.png",
    "final_image_with_boxes.png",
)

# Draw the bounding boxes (needs a `bounding_boxes` list of [cls, x1, y1, x2, y2],
# such as the one returned by process_input_simplest):
# draw_bounding_boxes_with_opacity(image_path, bounding_boxes, output_path)

In [None]:
import os
import cv2
from ultralytics import YOLOv10


def process_images(input_dir, output_dir, model):
    """Write YOLO-format pseudo-labels for every ``.png`` image in ``input_dir``.

    For each image the two most confident tools (class 0) and the two most
    confident tooltips (class 1) are written to ``output_dir/<stem>.txt``, one
    ``"cls x_center y_center width height"`` line per box, normalised by the
    image width/height. Existing label files are left untouched (so an
    interrupted run can be resumed), unreadable images are reported and
    skipped, and images without detections get an empty label file.

    Args:
        input_dir: directory containing the ``.png`` frames.
        output_dir: directory for the ``.txt`` labels (created if needed).
        model: Ultralytics-style detector; ``model(img)[0].boxes`` must expose
            ``cls``, ``conf`` and ``xywh``.
    """
    os.makedirs(output_dir, exist_ok=True)

    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith(".png"):
            continue
        output_txt_path = os.path.join(
            output_dir, os.path.splitext(filename)[0] + ".txt"
        )
        # Check for an existing label before decoding the image so resuming a
        # partially processed folder is cheap.
        if os.path.exists(output_txt_path):
            continue

        img = cv2.imread(os.path.join(input_dir, filename))
        if img is None:
            print(f"Could not read {filename}; skipping.")
            continue
        img_h, img_w = img.shape[:2]

        detections = model(img, verbose=False)[0].boxes
        tools = [det for det in detections if int(det.cls.item()) == 0]
        tooltips = [det for det in detections if int(det.cls.item()) == 1]
        tools_sorted = sorted(tools, key=lambda det: det.conf.item(), reverse=True)[:2]
        tooltips_sorted = sorted(tooltips, key=lambda det: det.conf.item(), reverse=True)[:2]

        yolo_format_data = []
        for det in tools_sorted + tooltips_sorted:
            x_center, y_center, width, height = det.xywh[0].cpu().numpy()
            cls = int(det.cls.item())
            # One box per line: the previous "/n" literal merged all boxes
            # into a single malformed line.
            yolo_format_data.append(
                f"{cls} {x_center / img_w} {y_center / img_h} "
                f"{width / img_w} {height / img_h}\n"
            )

        with open(output_txt_path, "w") as f:
            f.writelines(yolo_format_data)

        print(f"Processed {filename} and saved results to {output_txt_path}")


if __name__ == "__main__":
    # Detector used to pre-label every test sequence.
    model = YOLOv10("chkpts/6DOF/v10x/yolov10x-detect-6dof/weights/best.pt").to(device)

    # Test 5 is excluded from this run (reason not recorded here).
    SKIPPED_TESTS = {5}
    for i in range(1, 25):
        if i in SKIPPED_TESTS:
            continue
        input_dir = f"H:/Data/6DOF/Test {i} png"
        output_dir = f"H:/Data/6DOF/Test {i} txt"
        if not os.path.exists(input_dir):
            print(f"Directory {input_dir} does not exist. Skipping...")
            continue
        process_images(input_dir, output_dir, model)

In [None]:
import os
import cv2
import numpy as np
from scipy.spatial.transform import Rotation as R
from ultralytics import YOLOv10


def draw_arrow(image, start_point, end_point, color, thickness=2):
    """Draw an arrow from ``start_point`` to ``end_point`` onto ``image`` in place.

    The arrow head is 30% of the arrow's length (``tipLength=0.3``).
    """
    cv2.arrowedLine(
        image,
        start_point,
        end_point,
        color,
        thickness,
        tipLength=0.3,
    )


def quaternion_to_axes(q):
    """Return the rotated x, y and z unit vectors for quaternion ``q``.

    ``q`` is in SciPy's default scalar-last order ``(qx, qy, qz, qw)``. Each
    returned value is a length-3 NumPy array: the image of the corresponding
    basis vector under the rotation.
    """
    # Row i of the result is basis vector e_i after rotation.
    rotated_basis = R.from_quat(q).apply(np.eye(3))
    x_axis, y_axis, z_axis = rotated_basis
    return x_axis, y_axis, z_axis


def process_single_image(image_path, label_file, model):
    """Draw pose axes at the detected tooltip centres of one image.

    The two most confident tooltip (class 1) detections are paired, in
    confidence order, with the poses on lines 3 and 4 of ``label_file``
    (first -> fenestrated, second -> curved). Each pose's quaternion is turned
    into x/y/z axis vectors whose first two components are drawn as 50 px
    arrows from the tooltip centre (x in the tool colour, y green, z magenta).
    The result is saved next to the input as ``<stem>_with_arrows<ext>``.

    NOTE(review): pairing by confidence rank does not guarantee that a
    tooltip receives the pose of the tool it belongs to — confirm the
    intended mapping.

    :raises FileNotFoundError: if ``image_path`` cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # Perform YOLO detection and keep the two most confident tooltips.
    detections = model(image, save=False, verbose=False)[0].boxes
    tooltips = [det for det in detections if int(det.cls.item()) == 1]
    tooltips_sorted = sorted(tooltips, key=lambda det: det.conf.item(), reverse=True)[:2]

    # Load pose data from label file.
    # NOTE(review): lines 3/4 are assumed to be "<name> x y z qx qy qz qw" — confirm.
    with open(label_file, "r") as f:
        lines = f.readlines()
    fenestrated_pose = list(map(float, lines[2].split()[1:]))  # (x, y, z, qx, qy, qz, qw)
    curved_pose = list(map(float, lines[3].split()[1:]))  # (x, y, z, qx, qy, qz, qw)

    # x-axis arrow colours per tool (BGR): Fenestrated blue, Curved orange.
    arrow_colors = {
        1: (255, 0, 0),
        2: (0, 165, 255),
    }
    tool_pose_map = {1: fenestrated_pose, 2: curved_pose}
    arrow_length = 50  # pixels

    for tool_id, tooltip in enumerate(tooltips_sorted, start=1):
        x1, y1, x2, y2 = tooltip.xyxy[0].cpu().numpy().astype(int)
        # Plain Python ints: OpenCV's point parsing can reject NumPy integer
        # scalars in some versions.
        center = (int((x1 + x2) // 2), int((y1 + y2) // 2))

        q = tool_pose_map[tool_id][3:7]  # (qx, qy, qz, qw)
        x_axis, y_axis, z_axis = quaternion_to_axes(q)

        for axis, color in (
            (x_axis, arrow_colors[tool_id]),
            (y_axis, (0, 255, 0)),  # Green for y-axis
            (z_axis, (255, 0, 255)),  # Magenta for z-axis
        ):
            end_point = (
                int(center[0] + axis[0] * arrow_length),
                int(center[1] + axis[1] * arrow_length),
            )
            draw_arrow(image, center, end_point, color, thickness=2)

    # splitext keeps non-.png inputs from being overwritten (str.replace would be a no-op).
    root, ext = os.path.splitext(image_path)
    output_image_path = f"{root}_with_arrows{ext}"
    cv2.imwrite(output_image_path, image)
    print(f"Processed image saved as {output_image_path}")


if __name__ == "__main__":
    # Load the YOLO model on the device chosen at the top of the notebook
    # (CPU when CUDA is unavailable) instead of hard-coding "cuda".
    model = YOLOv10("chkpts/6DOF/v10x/yolov10x-detect-6dof/weights/best.pt").to(device)

    # Specify the input image and label file
    image_path = "H:/Data/6DOF/Test 2 png/test2_0000.png"
    label_file = "H:/Data/6DOF/Test 2/0.txt"

    # Process the image
    process_single_image(image_path, label_file, model)