In [None]:
import os
import numpy as np
import random
import cv2
from scipy.spatial.distance import jensenshannon

# --- Inputs (paths are relative to the notebook's working directory) ---
hist_dir = "../../color_histograms/Indoor"          # per-track histogram .npy files
mot_dir = "../../filtered_BoT-SORT_outputs/Indoor"  # one "<video>.txt" tracking file per video
video_dir = "../../videos/Indoor"                   # source .mp4 videos

# --- Outputs (created up front so later writes do not fail on a missing dir) ---
output_mot_dir = "../../filtered_BoT-SORT_outputs/merged"
output_video_dir = "../../videos/Indoor/merged"

for _directory in (output_mot_dir, output_video_dir):
    os.makedirs(_directory, exist_ok=True)
del _directory

# 1. Histogram loading (track histograms of one video, ordered by track ID)
def load_histograms(hist_dir, video_name, num_base=6):
    """Load every per-track histogram saved for *video_name*.

    Files are expected in *hist_dir* with names of the form
    ``<video_name>_..._track<ID>.npy``; the integer between the last
    ``_track`` and ``.npy`` is taken as the track ID. ``.npy`` files whose ID
    cannot be parsed are skipped.

    Args:
        hist_dir: Directory containing the histogram ``.npy`` files.
        video_name: Video stem. Only files starting with ``video_name + "_"``
            are used, so e.g. ``clip1`` does not pick up ``clip10_*`` files
            (whose track IDs would otherwise collide with this video's).
        num_base: How many of the smallest track IDs become "base" IDs, i.e.
            the identities other tracks are merged into. Defaults to 6.

    Returns:
        ``(histograms, base_ids)``: ``histograms`` maps track ID -> loaded
        array; ``base_ids`` is the set of the ``num_base`` smallest track IDs.
    """
    prefix = f"{video_name}_"

    # Collect (track_id, filename) pairs; sorting the tuples orders by track ID
    # ascending, with the filename as a deterministic tie-breaker.
    entries = []
    for name in os.listdir(hist_dir):
        if not (name.startswith(prefix) and name.endswith(".npy")):
            continue
        try:
            track_id = int(name.split("_track")[-1].split(".npy")[0])
        except ValueError:
            # Not a "..._track<ID>.npy" file: ignore it instead of aborting the run.
            continue
        entries.append((track_id, name))
    entries.sort()

    histograms = {}
    for track_id, name in entries:
        histograms[track_id] = np.load(os.path.join(hist_dir, name))

    # Ascending order: the lowest IDs (presumably the earliest-created tracks)
    # are used as the base identities.
    base_ids = {track_id for track_id, _ in entries[:num_base]}
    return histograms, base_ids

# 2. Similarity calculation
def find_most_similar(histograms, target_id, base_ids, used_base_ids, *, verbose=True):
    """Return the base ID whose histogram is closest to *target_id*'s.

    Closeness is the Jensen-Shannon distance from
    ``scipy.spatial.distance.jensenshannon``. Both histograms are flattened
    first, so multi-dimensional histograms are compared as one distribution
    (without flattening, ``jensenshannon`` works column-wise and returns an
    array, which makes the ``<`` comparison raise).

    Args:
        histograms: Mapping of track ID -> histogram array.
        target_id: Track to find a match for.
        base_ids: Candidate base IDs. Iterated in ascending order, so ties go
            to the smallest ID.
        used_base_ids: Not used; kept so existing callers keep working.
        verbose: If true (default), print each comparison and the result.

    Returns:
        The best-matching base ID, or ``None`` if the target has no histogram,
        no candidate has one, or every distance is NaN (e.g. all-zero
        histograms).
    """
    best_match = None
    min_distance = float("inf")

    target_hist = histograms.get(target_id)
    if target_hist is not None:
        target_flat = np.ravel(target_hist)
        for base_id in sorted(base_ids):
            base_hist = histograms.get(base_id)
            if base_hist is None:
                continue
            if verbose:
                print(f"Comparing target_id: {target_id} with base_id: {base_id}")
            dist = jensenshannon(target_flat, np.ravel(base_hist))
            # A NaN distance compares False here, so such candidates are never picked.
            if dist < min_distance:
                min_distance = dist
                best_match = base_id

    if verbose:
        print(f"Best match for target_id {target_id}: {best_match}")
    return best_match


def generate_merge_map(hist_dir, video_name, mot_dir):
    """Map each track ID that has a histogram to the base ID it is merged into.

    Base IDs (from ``load_histograms``) map to themselves. Every other track
    is matched by histogram similarity against the base IDs that never share
    a frame with it. A base visible at the same time cannot be the same
    person, and merging into it would put two boxes with one ID in a frame.
    If every base co-occurs with the track, or no similarity match is found,
    the smallest eligible base ID is used as a fallback. Tracks that have a
    histogram but no rows in the MOT file are left out of the map, so they
    keep their original ID.

    NOTE(review): two non-base tracks that co-occur can still be merged into
    the same base; that case is not checked here.

    Args:
        hist_dir: Directory holding the per-track histogram files.
        video_name: Video stem; also names the MOT file ``<video_name>.txt``.
        mot_dir: Directory holding the MOT tracking files.

    Returns:
        dict mapping original track ID -> merged track ID (never ``None``).
    """
    histograms, base_ids = load_histograms(hist_dir, video_name)
    cooccurring = _load_cooccurring_ids(os.path.join(mot_dir, f"{video_name}.txt"))

    merge_map = {}
    for target_id in histograms:
        if target_id in base_ids:
            merge_map[target_id] = target_id
            continue

        seen_with = cooccurring.get(target_id)
        if seen_with is None:
            continue  # track never appears in the MOT file: leave it unmapped

        available_base_ids = base_ids - seen_with
        best_match = None
        if available_base_ids:
            best_match = find_most_similar(
                histograms, target_id, available_base_ids, set(merge_map.values())
            )
        if best_match is None:
            # Fallback: prefer a base never seen with the target; otherwise the
            # smallest base ID overall.
            best_match = min(available_base_ids or base_ids)
        merge_map[target_id] = best_match
    return merge_map


def _load_cooccurring_ids(mot_file):
    """Return track ID -> set of every track ID (itself included) sharing any frame with it."""
    frame_tracks = {}
    with open(mot_file, "r") as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue  # tolerate blank lines
            frame_id, track_id = map(int, stripped.split(",")[:2])
            frame_tracks.setdefault(frame_id, set()).add(track_id)

    cooccurring = {}
    for track_ids in frame_tracks.values():
        for track_id in track_ids:
            cooccurring.setdefault(track_id, set()).update(track_ids)
    return cooccurring


# MOT data correction
def modify_mot_file(mot_path, merge_map, output_path):
    """Rewrite a MOT text file with track IDs remapped through *merge_map*.

    Each non-blank line is split on commas and only field 1 (the track ID) is
    replaced. IDs missing from *merge_map*, or mapped to ``None``, are written
    unchanged, so a literal ``"None"`` never ends up in the output. Blank
    lines are skipped, and every written row ends with a newline.

    Args:
        mot_path: Input MOT file.
        merge_map: Mapping of original track ID -> new track ID.
        output_path: Destination file (overwritten).
    """
    rows = []
    with open(mot_path, "r") as src:
        for line in src:
            stripped = line.strip()
            if not stripped:
                continue
            fields = stripped.split(",")
            new_id = merge_map.get(int(fields[1]))
            if new_id is not None:
                fields[1] = str(new_id)
            rows.append(",".join(fields) + "\n")
    with open(output_path, "w") as dst:
        dst.writelines(rows)

def get_color(track_id):
    """Return a deterministic pseudo-random color for *track_id*.

    The result is a 3-tuple of ints in [0, 255], used as a drawing color.
    It comes from a private ``random.Random`` seeded with the ID. That gives
    the same colors as seeding the module-level generator, but leaves the
    global ``random`` state untouched for the rest of the program.
    """
    rng = random.Random(track_id)
    return (rng.randint(0, 255), rng.randint(0, 255), rng.randint(0, 255))

def visualize_tracking(video_path, mot_path, output_video_path):
    """Draw MOT boxes and track IDs on every frame of a video and save the result.

    Rows of *mot_path* are read as ``frame,id,x,y,w,h,...``. Only the first six
    comma-separated fields are used; shorter or non-numeric rows are skipped.
    Row frame IDs are matched against a 0-based frame counter.
    NOTE(review): MOT files are often 1-based; confirm against the tracker
    output, otherwise boxes are drawn one frame early.

    Failure to open the video, read the MOT file, or create the writer is
    printed and the function returns without raising. The capture and writer
    are always released.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}")
        return

    out = None
    try:
        track_data = _load_mot_boxes(mot_path)
        if track_data is None:
            print(f"Error: File {mot_path} not found.")
            return
        print(f"Loaded tracking data for {len(track_data)} frames.")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates (e.g. 29.97); int() would shorten/drift the output.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # 0 or NaN when the container does not report a rate
            fps = 30.0  # arbitrary fallback so the writer can still be created

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
        if not out.isOpened():
            print(f"Error: Could not open video writer {output_video_path}")
            return

        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            for track_id, x, y, w, h in track_data.get(frame_idx, ()):
                cv2.rectangle(frame, (x, y), (x + w, y + h), get_color(track_id), 2)
                # Clamp the label so it stays visible for boxes touching the top edge.
                cv2.putText(frame, str(track_id), (x, max(y - 5, 15)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
            out.write(frame)
            frame_idx += 1
    finally:
        cap.release()
        if out is not None:
            out.release()

    print(f"Processed {video_path}, saved to {output_video_path}")


def _load_mot_boxes(mot_path):
    """Parse a MOT file into frame_id -> [(track_id, x, y, w, h), ...] ints; None if the file is missing."""
    track_data = {}
    try:
        with open(mot_path, "r") as f:
            for line in f:
                parts = line.strip().split(",")
                if len(parts) < 6:
                    continue
                try:
                    frame_id = int(parts[0])
                    track_id = int(parts[1])
                    # Coordinates may be written as floats; truncate to pixels.
                    x, y, w, h = (int(float(v)) for v in parts[2:6])
                except ValueError:
                    continue
                track_data.setdefault(frame_id, []).append((track_id, x, y, w, h))
    except FileNotFoundError:
        return None
    return track_data


# Process all videos: merge track IDs, write the corrected MOT file, and render
# an annotated copy of each video. Files are visited in sorted order so runs are
# reproducible (os.listdir order is arbitrary); the extension check is
# case-insensitive so ".MP4" files are not silently skipped.
for video_file in sorted(os.listdir(video_dir)):
    if not video_file.lower().endswith(".mp4"):
        continue
    video_name = os.path.splitext(video_file)[0]
    mot_file = os.path.join(mot_dir, f"{video_name}.txt")
    if not os.path.exists(mot_file):
        print(f"Skipping {video_file}, no corresponding MOT file.")
        continue
    merge_map = generate_merge_map(hist_dir, video_name, mot_dir)
    modified_mot_file = os.path.join(output_mot_dir, f"{video_name}.txt")
    modify_mot_file(mot_file, merge_map, modified_mot_file)
    video_path = os.path.join(video_dir, video_file)
    output_video_path = os.path.join(output_video_dir, video_file)
    visualize_tracking(video_path, modified_mot_file, output_video_path)
# Message: "Merging and visualization of all videos is complete!"
print("すべての動画の統合・可視化が完了しました！")


Comparing target_id: 8 with base_id: 4
Best match for target_id 8: 4
Comparing target_id: 10 with base_id: 4
Best match for target_id 10: 4
Comparing target_id: 11 with base_id: 4
Best match for target_id 11: 4
Loaded tracking data for 102 frames.
Processed videos/Indoor/basket_S6T3_post.mp4, saved to videos/Indoor_merged/basket_S6T3_post.mp4
Comparing target_id: 8 with base_id: 3
Best match for target_id 8: 3
Comparing target_id: 10 with base_id: 3
Best match for target_id 10: 3
Comparing target_id: 14 with base_id: 3
Best match for target_id 14: 3
Comparing target_id: 15 with base_id: 3
Comparing target_id: 15 with base_id: 6
Best match for target_id 15: 6
Comparing target_id: 17 with base_id: 3
Comparing target_id: 17 with base_id: 6
Best match for target_id 17: 3
Loaded tracking data for 300 frames.
Processed videos/Indoor/basket_S6T4_post.mp4, saved to videos/Indoor_merged/basket_S6T4_post.mp4
Comparing target_id: 8 with base_id: 4
Best match for target_id 8: 4
Comparing target_id