# In [None]:
# YOLO object detector for custom-trained objects — here the javelin (weights: best.pt in this repository).

# === Import Libraries ===
import torch
import os
import time
import sys
import cv2
import gc
import numpy as np
# Re-create the `np.float` alias (removed from NumPy in 1.24) as float64.
# NOTE(review): presumably required by one of the libraries imported below
# that still references `np.float` — confirm which one before removing.
np.float = np.float64
import psutil
from tqdm.notebook import tqdm
from ultralytics import YOLO
from torchreid.utils import FeatureExtractor
import yaml

# Load the class names from the training dataset's YAML config.
# The file is read as UTF-8 explicitly so that non-ASCII class names decode
# the same way regardless of the platform's default encoding.
with open(
    "/Users/Christian/Downloads/Javelin training/YOLO javelin_2/yolo_dataset/dataset.yaml",
    "r",
    encoding="utf-8",
) as file:
    dataset_config = yaml.safe_load(file)

# Indexed by YOLO class id further down.
class_names = dataset_config["names"]


# === Select the compute device: Apple Metal (MPS) if available, else CPU ===
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"✅ Using {device} for computation.")

# === User Input for Paths ===
# All .mp4 files directly inside `input_folder` are processed; results go to
# the "YOLO_Object" sub-folder, which is created if it does not exist yet.
input_folder = "/Users/Christian/Downloads/Javelin training/analysis"
output_folder = os.path.join(input_folder, "YOLO_Object")
os.makedirs(output_folder, exist_ok=True)

# os.listdir() returns entries in arbitrary order; sort so that videos are
# always processed (and logged) in the same, reproducible order.
video_files = sorted(
    f for f in os.listdir(input_folder) if f.lower().endswith(".mp4")
)

# === Load the custom-trained YOLOv11 detector from its best checkpoint ===
weights_path = "/Users/Christian/Downloads/Javelin training/YOLO javelin_2/yolo_dataset/results/yolo_train2/weights/best.pt"
body_detector = YOLO(weights_path)
print("✅ Custom YOLOv11 Model loaded.")

# === Initialize Torchreid (OSNet) ===
# Appearance-embedding model used to describe each detection crop for ReID.
# It runs on the device selected above (MPS when available, otherwise CPU)
# rather than a hard-coded "mps", so the script also works without Metal.
extractor = FeatureExtractor(
    model_name='osnet_x1_0',
    model_path="/Users/Christian/Torch ReID/osnet_x1_0_msmt17.pth",
    device=device.type,
)
print(f"✅ OSNet ReID Model Initialized on {device.type.upper()}.")

# === Add BoT-SORT Tracker ===
# The BoT-SORT checkout is not an installed package, so its directory has to be
# on sys.path before the import. The membership check keeps sys.path free of
# duplicate entries when this cell is executed more than once.
bot_sort_dir = "/Users/Christian/Christian Home Drive/Christian/Projekte/CURRENTLY RUNNING PROJECTS/CV and NLP/Python_codes and apps/sort/BoT-SORT"
if bot_sort_dir not in sys.path:
    sys.path.append(bot_sort_dir)
from tracker.bot_sort_reid import BoTSORT
import argparse

# Tracker configuration handed to BoTSORT as an argparse-style namespace.
args = argparse.Namespace(
    track_high_thresh=0.7,
    track_low_thresh=0.04,
    new_track_thresh=0.24,
    track_buffer=7000,
    max_age=7000,
    n_init=4,
    match_thresh=0.93,
    mot20=False,
    proximity_thresh=0.35,
    appearance_thresh=0.85,
    cmc_method="sparseOptFlow",
    name="BoT-SORT",
    ablation=False,
    with_reid=True,
    lambda_=0.98,
    use_byte=False,
    # Follow the device selected earlier instead of hard-coding 'mps', so the
    # configuration stays valid when the script falls back to CPU.
    device=device.type,
    imgsz=640,
)

# === Process Videos ===
# For each input video: detect objects per frame with YOLO, embed every
# detection crop with OSNet, update a per-video BoT-SORT tracker and write one
# CSV row per returned track (Frame, Track_ID, Label, X1, Y1, X2, Y2).

DETECTION_CONF = 0.20   # detector threshold, also re-checked per box below
MIN_CROP_SIDE = 2       # a crop must exceed this many pixels in both dimensions
FALLBACK_FPS = 30.0     # used when the video reports no usable frame rate


def _label_for_track(track_box, det_boxes, det_labels):
    """Return the label of the detection that overlaps *track_box* the most.

    track_box is (x1, y1, x2, y2); det_boxes is an (N, >=4) float array whose
    first four columns are x1, y1, x2, y2; det_labels holds the N labels in
    the same order. Returns "unknown" when no detection overlaps the track.
    """
    if len(det_labels) == 0:
        return "unknown"
    tx1, ty1, tx2, ty2 = track_box
    inter_w = np.clip(np.minimum(det_boxes[:, 2], tx2) - np.maximum(det_boxes[:, 0], tx1), 0, None)
    inter_h = np.clip(np.minimum(det_boxes[:, 3], ty2) - np.maximum(det_boxes[:, 1], ty1), 0, None)
    inter = inter_w * inter_h
    det_area = (det_boxes[:, 2] - det_boxes[:, 0]) * (det_boxes[:, 3] - det_boxes[:, 1])
    track_area = max(tx2 - tx1, 0) * max(ty2 - ty1, 0)
    union = det_area + track_area - inter
    iou = np.where(union > 0, inter / np.maximum(union, 1e-9), 0.0)
    best = int(np.argmax(iou))
    return det_labels[best] if iou[best] > 0 else "unknown"


for video_file in video_files:
    input_path = os.path.join(input_folder, video_file)
    csv_path = os.path.join(output_folder, f"{os.path.splitext(video_file)[0]}.csv")
    print(f"\n🎬 Now processing: {video_file}")

    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        print(f"❌ Could not open {input_path}")
        cap.release()
        continue

    fps = cap.get(cv2.CAP_PROP_FPS)
    if not (fps > 0):
        # Covers 0, negative and NaN: the tracker derives its buffer length
        # from the frame rate, so a missing value must not reach it.
        print(f"⚠️ No valid FPS reported for {video_file}; assuming {FALLBACK_FPS}")
        fps = FALLBACK_FPS
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    progress_bar = tqdm(total=total_frames, desc="Processing Frames", unit="frame", leave=False)

    frame_num = 0
    columns = ['Frame', 'Track_ID', 'Label', 'X1', 'Y1', 'X2', 'Y2']

    try:
        # A fresh tracker per video so track IDs do not carry over.
        tracker = BoTSORT(args, frame_rate=fps)

        # The CSV is opened once per video instead of once per written row.
        with open(csv_path, 'w', encoding="utf-8") as csv_file, torch.no_grad():
            csv_file.write(','.join(columns) + '\n')

            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frame_num += 1
                progress_bar.update(1)

                # Periodic memory relief when system RAM usage is high.
                if frame_num % 100 == 0 and psutil.virtual_memory().percent > 80:
                    gc.collect()
                    if device.type == "mps":
                        torch.mps.empty_cache()

                results = body_detector(frame, conf=DETECTION_CONF, iou=0.55, verbose=False, imgsz=1024)

                # These three lists are kept strictly parallel: a detection is
                # only recorded if it also yields a usable crop, so that every
                # row handed to the tracker has a matching embedding.
                tracked_bodies, body_crops, labels = [], [], []

                for result in results:
                    for box in result.boxes:
                        conf = box.conf[0].item()
                        if conf <= DETECTION_CONF:
                            continue
                        x1, y1, x2, y2 = map(int, box.xyxy[0])
                        crop = frame[y1:y2, x1:x2]
                        if crop.shape[0] <= MIN_CROP_SIDE or crop.shape[1] <= MIN_CROP_SIDE:
                            continue
                        cls_id = int(box.cls[0])
                        label = class_names[cls_id] if cls_id < len(class_names) else "unknown"
                        tracked_bodies.append([x1, y1, x2, y2, conf])
                        labels.append(label)
                        # OpenCV frames are BGR; the crop is converted to RGB
                        # before being passed to the feature extractor.
                        body_crops.append(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))

                # NOTE(review): frames without detections skip the tracker
                # update entirely (original behavior, kept as-is).
                if not body_crops:
                    continue

                embeddings = extractor(body_crops)
                body_features = np.array([e.cpu().numpy() for e in embeddings])
                tracked_bodies_np = np.array(tracked_bodies, dtype=np.float64).reshape(-1, 5)

                try:
                    tracked_results = tracker.update(tracked_bodies_np, frame, body_features)
                except Exception as e:
                    # Deliberate best-effort: a failing frame is reported and skipped.
                    print(f"⚠️ Tracking failed on frame {frame_num}: {e}")
                    continue

                rows = []
                for track in tracked_results:
                    x1, y1, w, h = map(int, track.tlwh)
                    x2, y2 = x1 + w, y1 + h
                    track_id = int(track.track_id)
                    # The order of returned tracks is not assumed to match the
                    # order of this frame's detections, so the label is taken
                    # from the detection that overlaps the track box the most.
                    label = _label_for_track((x1, y1, x2, y2), tracked_bodies_np, labels)
                    rows.append(f"{frame_num},{track_id},{label},{x1},{y1},{x2},{y2}\n")
                csv_file.writelines(rows)
    finally:
        # Always free the capture handle and the progress bar, even on errors.
        cap.release()
        progress_bar.close()

    print(f"✅ Finished {video_file}")

print("\n🎉 All sessions processed!")