# Libraries

In [None]:
import numpy as np
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment
import os
from ffprobe import FFProbe
import cv2
import json

# Global Variables

In [None]:
# ===================== CONFIGURATION / HYPERPARAMETERS =======================

# Association gate: a track/detection pair whose matching cost (see
# compute_cost) is above this value is never allowed to be matched.
COST_THRESHOLD: float = 5

# A track that goes this many consecutive frames without a matched detection
# is dropped from the active set.
MAX_UNMATCHED_FRAMES: int = 100

# Sentinel written into the cost matrix for gated-out pairs; an assignment
# that lands on this value is discarded as "no match". It is a marker value,
# not a threshold.
LARGE_COST: float = 1e9

# Identifier given to the first track created; later tracks count up from it.
TRACK_ID_START: int = 1

# =============================================================================

# Helper Functions

In [None]:
def iou_bbox(b1, b2):
    """
    Intersection-over-Union of two axis-aligned boxes.

    Both boxes are (x1, y1, x2, y2) corner tuples in the same coordinate
    system. Returns inter_area / union_area, or 0.0 when the union area is
    (numerically) zero.
    """
    left, top = max(b1[0], b2[0]), max(b1[1], b2[1])
    right, bottom = min(b1[2], b2[2]), min(b1[3], b2[3])

    # Disjoint boxes produce a negative extent on some axis; clip it to zero.
    overlap = max(0., right - left) * max(0., bottom - top)

    area_a = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area_b = (b2[2] - b2[0]) * (b2[3] - b2[1])
    union = area_a + area_b - overlap

    return 0. if union < 1e-9 else overlap / union

In [None]:
def center_distance(b1, b2):
    """
    Euclidean distance between the midpoints of two (x1, y1, x2, y2) boxes,
    in the boxes' own units.
    """
    dx = 0.5*(b1[0] + b1[2]) - 0.5*(b2[0] + b2[2])
    dy = 0.5*(b1[1] + b1[3]) - 0.5*(b2[1] + b2[3])
    return np.hypot(dx, dy)

In [None]:
def xywh_to_xyxy(b):
    """
    Turn a center-format box (x_center, y_center, w, h) into corner format
    (x1, y1, x2, y2). The conversion is unit-agnostic, so pixel and
    normalized boxes both work.
    """
    x_c, y_c, w, h = b
    half_w, half_h = w/2, h/2
    return (x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h)

# Helper Classes

In [None]:
class MouseKalmanFilter:
    """
    SORT-style constant-velocity Kalman filter for one bounding box.

    Boxes go in and come out as (x_center, y_center, w, h). Internally the
    state is x = [x, y, s, r, vx, vy, vs] with s = w*h (area) and r = w/h
    (aspect ratio); x, y and s have velocities, r is held constant. The
    measurement vector is z = [x, y, s, r].
    """

    # Floor applied to the area `s` and the aspect ratio `r` when converting
    # the state back to (w, h), so the square roots stay real and finite.
    _MIN_POSITIVE = 1e-6

    def __init__(self, init_bbox, init_frame=0):
        """
        Args:
            init_bbox: first observed box, (x_center, y_center, w, h).
            init_frame: accepted for interface compatibility; not used.
        """
        self.kf = KalmanFilter(dim_x=7, dim_z=4)
        dt = 1.  # one prediction step per frame
        self.kf.F = np.array([
            [1, 0, 0, 0, dt, 0,  0],
            [0, 1, 0, 0, 0,  dt, 0],
            [0, 0, 1, 0, 0,  0,  dt],
            [0, 0, 0, 1, 0,  0,  0 ],
            [0, 0, 0, 0, 1,  0,  0 ],
            [0, 0, 0, 0, 0,  1,  0 ],
            [0, 0, 0, 0, 0,  0,  1 ]
        ], dtype=float)
        self.kf.H = np.array([
            [1, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0]
        ], dtype=float)

        # Uncertainty scalings are hyperparameters: velocities start out
        # essentially unknown, and the area / aspect-ratio measurements get a
        # larger measurement noise than the center.
        self.kf.P[4:, 4:] *= 1000.
        self.kf.P *= 10.
        self.kf.R[2:, 2:] *= 10.

        # Seed the state with the first box, then fold it in as a measurement.
        self.kf.x[:4] = self._bbox_to_z(init_bbox).reshape(-1, 1)
        self.update(init_bbox)

    @staticmethod
    def _bbox_to_z(bbox):
        """Convert (x, y, w, h) into the measurement vector [x, y, s, r]."""
        x, y, w, h = bbox
        s = w*h                  # scale ~ area
        r = w/float(h+1e-6)      # aspect ratio; epsilon guards h == 0
        return np.array([x, y, s, r])

    def predict(self):
        """Advance the state one frame and return the predicted (x, y, w, h)."""
        # As in SORT: if the area velocity would push the area to zero or
        # below, drop that velocity so the box does not collapse and keep
        # collapsing on every subsequent prediction.
        if (self.kf.x[6] + self.kf.x[2]) <= 0:
            self.kf.x[6] *= 0.0
        self.kf.predict()
        return self.get_bbox()

    def update(self, bbox):
        """Correct the state with an observed (x, y, w, h) box; return the new box."""
        self.kf.update(self._bbox_to_z(bbox))
        return self.get_bbox()

    def get_bbox(self):
        """
        Return the current state as (x_center, y_center, w, h).

        Non-positive (or NaN) area and aspect-ratio values are clamped to
        _MIN_POSITIVE, both in the returned box and in the filter state, so
        the box never contains NaN/inf. Otherwise those values would poison
        the downstream cost matrix.
        """
        x, y, s, r = self.kf.x[:4].reshape(-1)

        if not s > 0:
            s = self._MIN_POSITIVE
            self.kf.x[2] = s
        if not r > 0:
            r = self._MIN_POSITIVE
            self.kf.x[3] = r

        w = np.sqrt(s*r)
        h = np.sqrt(s/r)
        return (x, y, w, h)

In [None]:
class Track:
    """
    One tracked mouse: a MouseKalmanFilter plus tracker bookkeeping.

    Attributes:
        id: identifier assigned by the tracker.
        kf: MouseKalmanFilter holding the box state.
        keypoints: 'keypoints' of the most recent detection assigned to this
            track (the detection's own object, not a copy), or None.
        time_since_update: predictions made since the last assigned detection.
        hits: number of detections assigned so far, including the first one.
    """
    def __init__(self, detection, track_id):
        # The filter expects detection['bbox'] as (x_center, y_center, w, h).
        self.id = track_id
        self.kf = MouseKalmanFilter(detection['bbox'])
        self.keypoints = detection.get('keypoints', None)
        self.time_since_update = 0
        self.hits = 1

    def predict(self):
        """Step the filter forward one frame; the step counts as 'not updated'."""
        bbox_guess = self.kf.predict()
        self.time_since_update += 1
        return bbox_guess

    def update(self, detection):
        """Fold a matched detection into the filter and reset the staleness counter."""
        self.kf.update(detection['bbox'])
        self.keypoints = detection.get('keypoints', None)
        self.time_since_update = 0
        self.hits = self.hits + 1

    def get_bbox_xyxy(self):
        """Current box in corner format (x1, y1, x2, y2), for costs and drawing."""
        return xywh_to_xyxy(self.kf.get_bbox())

    def get_bbox_xywh(self):
        """Current box in center format (x_center, y_center, w, h)."""
        return self.kf.get_bbox()

# Helper Functions

In [None]:
def compute_cost(track: Track, detection: dict, alpha=0.5, dist_scale=100.0):
    """
    Matching cost between a track's current box and a detection (lower is better).

        cost = alpha * (1 - IoU) + (1 - alpha) * center_distance / dist_scale

    The IoU term lies in [0, 1], but the center-distance term is unbounded.
    For alpha in [0, 1] the cost is therefore >= 0 with no fixed upper bound;
    callers gate it (see COST_THRESHOLD).

    Other cues (keypoints, appearance, ...) could be added here.

    Args:
        track: Track whose Kalman-filtered box is compared.
        detection: dict whose 'bbox' is (x_center, y_center, w, h).
        alpha: weight of the IoU term; (1 - alpha) weights the distance term.
        dist_scale: divisor normalizing the center distance, in the same units
            as the boxes. The default 100.0 matches the original scaling.

    Returns:
        The scalar cost.
    """
    track_xyxy = track.get_bbox_xyxy()
    det_xyxy   = xywh_to_xyxy(detection['bbox'])

    iou_val   = iou_bbox(track_xyxy, det_xyxy)
    cdist_val = center_distance(track_xyxy, det_xyxy)

    return alpha * (1.0 - iou_val) + (1 - alpha) * (cdist_val / dist_scale)

In [None]:
def associate_detections_to_tracks(tracks, detections, alpha=0.5, cost_threshold=None):
    """
    Match detections to tracks with the Hungarian algorithm.

    A (len(tracks), len(detections)) cost matrix is filled from compute_cost.
    Pairs whose cost exceeds the gate, or is NaN, keep the LARGE_COST
    sentinel. Any assignment that lands on LARGE_COST is discarded as
    "no match".

    Args:
        tracks: list of Track.
        detections: list of detection dicts (see compute_cost).
        alpha: IoU weight forwarded to compute_cost.
        cost_threshold: gating threshold. None (default) uses the
            module-level COST_THRESHOLD at call time.

    Returns:
      matched_pairs: list of (track_idx, detection_idx)
      unmatched_tracks: set of track indices
      unmatched_detections: set of detection indices
    """
    if cost_threshold is None:
        cost_threshold = COST_THRESHOLD

    n_tracks, n_dets = len(tracks), len(detections)
    if n_tracks == 0 or n_dets == 0:
        return [], set(range(n_tracks)), set(range(n_dets))

    # Every pair starts out forbidden; only pairs passing the gate are opened.
    cost_matrix = np.full((n_tracks, n_dets), LARGE_COST, dtype=np.float32)
    for i, trk in enumerate(tracks):
        for j, det in enumerate(detections):
            c = compute_cost(trk, det, alpha=alpha)
            # `c <= cost_threshold` is False for NaN. A NaN cost therefore
            # stays at LARGE_COST instead of reaching linear_sum_assignment,
            # which rejects NaN entries with a ValueError.
            if c <= cost_threshold:
                cost_matrix[i, j] = c

    row_idx, col_idx = linear_sum_assignment(cost_matrix)
    matched_pairs = [
        (r, c) for r, c in zip(row_idx, col_idx)
        if cost_matrix[r, c] < LARGE_COST
    ]

    matched_track_indices = {r for r, _ in matched_pairs}
    matched_det_indices   = {c for _, c in matched_pairs}

    unmatched_tracks = set(range(n_tracks)) - matched_track_indices
    unmatched_detections = set(range(n_dets)) - matched_det_indices
    return matched_pairs, unmatched_tracks, unmatched_detections

In [None]:
def _frames_elapsed(prev_frame_idx, frame_idx):
    """
    Number of one-frame predictions needed to go from `prev_frame_idx` to
    `frame_idx`. Always >= 1 and capped at MAX_UNMATCHED_FRAMES.
    """
    integer_types = (int, np.integer)
    if (prev_frame_idx is None
            or not isinstance(frame_idx, integer_types)
            or not isinstance(prev_frame_idx, integer_types)):
        # Non-integer keys carry no frame spacing: fall back to a single step.
        return 1
    gap = int(frame_idx) - int(prev_frame_idx)
    # A track unmatched for MAX_UNMATCHED_FRAMES is removed anyway, so
    # predicting further ahead than that would only waste work.
    return max(1, min(gap, MAX_UNMATCHED_FRAMES))


def track_multi_mice(all_detections, alpha=0.5):
    """
    SORT-like multi-object tracking over per-frame detections.

    For each frame, in sorted key order:
      1. predict every active track forward to this frame,
      2. build the cost matrix and solve the assignment (Hungarian),
      3. update matched tracks with their detection,
      4. let unmatched tracks coast on their prediction,
      5. open a new track for every unmatched detection,
      6. drop tracks unmatched for MAX_UNMATCHED_FRAMES or more frames.

    When frame keys are integers and some frames are absent from
    `all_detections` (presumably frames without a label file), tracks are
    predicted once per elapsed frame rather than once per processed frame.

    Args:
        all_detections: {frame_idx: [{'bbox': (x, y, w, h), 'keypoints': ...}, ...]}
        alpha: IoU-vs-center-distance weight forwarded to the association cost.

    Returns:
        {frame_idx: {track_id: {'bbox': (x_center, y_center, w, h),
        'keypoints': ...}}} for every processed frame. All active tracks are
        reported, including tracks that were only predicted in that frame;
        their 'keypoints' come from the last detection they were matched to.
    """
    frame_indices = sorted(all_detections.keys())
    active_tracks = []
    next_id = TRACK_ID_START
    results_per_frame = {}
    prev_frame_idx = None

    for frame_idx in frame_indices:
        detections = all_detections[frame_idx]

        # 1) PREDICT: one step per frame elapsed since the last processed frame.
        n_steps = _frames_elapsed(prev_frame_idx, frame_idx)
        prev_frame_idx = frame_idx
        for trk in active_tracks:
            for _ in range(n_steps):
                trk.predict()

        # 2) ASSOCIATE
        matched_pairs, _unmatched_tracks, unmatched_dets = associate_detections_to_tracks(
            active_tracks, detections, alpha=alpha
        )

        # 3) UPDATE matched tracks
        for trk_idx, det_idx in matched_pairs:
            active_tracks[trk_idx].update(detections[det_idx])

        # 4) Unmatched tracks keep their prediction; Track.predict() has
        #    already incremented their time_since_update.

        # 5) CREATE new tracks, in detection order so ids are assigned predictably.
        for det_idx in sorted(unmatched_dets):
            active_tracks.append(Track(detections[det_idx], next_id))
            next_id += 1

        # 6) REMOVE tracks that have been unmatched too long
        active_tracks = [
            trk for trk in active_tracks
            if trk.time_since_update < MAX_UNMATCHED_FRAMES
        ]

        # 7) Collect results
        results_per_frame[frame_idx] = {
            trk.id: {'bbox': trk.get_bbox_xywh(), 'keypoints': trk.keypoints}
            for trk in active_tracks
        }

    return results_per_frame

In [None]:
def get_video_resolution(filename):
    """
    Return (width, height) in pixels of the first video stream in `filename`.

    Stream metadata comes from the `ffprobe` package's FFProbe wrapper.
    Returns (None, None) when the file has no video stream.
    """
    metadata = FFProbe(filename)
    for stream in metadata.streams:
        if stream.is_video():
            return (int(stream.width), int(stream.height))

    return (None, None)

In [None]:
def isPointInBBox(x, y, x1, y1, x2, y2):
    """Return True when (x, y) lies inside or on the border of [x1, x2] x [y1, y2]."""
    return x1 <= x <= x2 and y1 <= y <= y2

In [None]:
def yolo_txt_to_detection(
    txt_path, 
    frame_index,  
    image_width, 
    image_height,
    mAnnotated_flag,
    visiblePercentage,
    keypoint_names=None
):
    """
    Parse one YOLO-like label file (normalized bbox + keypoints) into detections.

    Each non-empty line is expected as
        class_id x_c y_c w h  kx_1 ky_1 v_1  ...  kx_K ky_K v_K
    with coordinates normalized by the image size and K = len(keypoint_names).

    Args:
        txt_path: path of the label file.
        frame_index: key under which this file's detections are returned.
        image_width, image_height: frame size in pixels, used to denormalize.
        mAnnotated_flag: currently unused (the field it fed is commented out).
        visiblePercentage: a keypoint whose third value exceeds this gets
            visibility 2; otherwise it gets 1.
        keypoint_names: names of the K keypoints in file order; defaults to
            ["nose", "earL", "earR", "tailB"].

    Returns:
        {frame_index: [
            {
              "bbox": (x_center, y_center, w, h),                       # pixels
              "bbox_xY": {"x1": ..., "y1": ..., "x2": ..., "y2": ...},  # pixels
              "keypoints": {name: [x_px, y_px, visibility], ...}
            },
            ...
        ]}
        Boxes with zero or negative width/height are skipped. Keypoints
        outside their box are omitted, and keypoint pixel coordinates are
        truncated to int.

    Raises:
        IndexError / ValueError if a line has too few tokens or non-numeric fields.
    """
    if keypoint_names is None:
        keypoint_names = ["nose", "earL", "earR", "tailB"]
    num_kpts = len(keypoint_names)

    annotations = {frame_index: []}

    with open(txt_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue

            tokens = line.split()
            # First 5 tokens: class_id, x_center, y_center, w, h (normalized).
            class_id    = int(tokens[0])  # parsed (validates the field) but not used further
            x_center_n  = float(tokens[1])
            y_center_n  = float(tokens[2])
            w_n         = float(tokens[3])
            h_n         = float(tokens[4])

            # Denormalize bounding box
            x_center = x_center_n * image_width
            y_center = y_center_n * image_height
            w        = w_n * image_width
            h        = h_n * image_height

            x1 = x_center - w / 2
            y1 = y_center - h / 2
            x2 = x_center + w / 2
            y2 = y_center + h / 2

            # Skip empty, inverted or NaN boxes. A float `==` test would let
            # negative extents through, which later break the Kalman
            # aspect ratio.
            if not (x2 > x1 and y2 > y1):
                continue

            # Each keypoint contributes (x_n, y_n, v) starting at token 5.
            keypoints_dict = {}
            for i, kpt_name in enumerate(keypoint_names[:num_kpts]):
                base = 5 + 3*i
                x_kpt = float(tokens[base]) * image_width
                y_kpt = float(tokens[base + 1]) * image_height
                v_kpt = float(tokens[base + 2])

                if not isPointInBBox(x_kpt, y_kpt, x1, y1, x2, y2):
                    continue

                keypoints_dict[kpt_name] = [int(x_kpt), int(y_kpt), 2 if v_kpt > visiblePercentage else 1]

            annotations[frame_index].append({
                "bbox": (x_center, y_center, w, h),
                "bbox_xY": {
                    "x1": x1,
                    "y1": y1,
                    "x2": x2,
                    "y2": y2
                },
                "keypoints": keypoints_dict
                # "mAnnotated": mAnnotated_flag
            })

    return annotations

In [None]:
# Fallback colors for ids / keypoint names missing from the caller's palettes.
_DEFAULT_BOX_COLOR = (255, 255, 255)
_DEFAULT_KPT_COLOR = (0, 255, 0)


def _lookup_by_key(mapping, key, default=None):
    """
    Return mapping[key], also trying str(key) and int(key).

    This lets int keys and their string form (as produced by a json.dump /
    json.load round trip) be used interchangeably.
    """
    if key in mapping:
        return mapping[key]
    if str(key) in mapping:
        return mapping[str(key)]
    try:
        as_int = int(key)
    except (TypeError, ValueError):
        return default
    return mapping.get(as_int, default)


def _draw_mouse(frame, mouse_id, mouse_data, color_box, color_kpt):
    """Draw one mouse's box, id label and keypoints onto `frame` in place."""
    bbox = mouse_data['bbox']
    x1, y1 = int(bbox['x1']), int(bbox['y1'])
    x2, y2 = int(bbox['x2']), int(bbox['y2'])

    box_color = _lookup_by_key(color_box, mouse_id, _DEFAULT_BOX_COLOR)
    cv2.rectangle(frame, (x1, y1), (x2, y2), box_color, 2)
    cv2.putText(frame, f"Mouse {mouse_id}", (x1, y1 - 5),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, box_color, 2)

    for kpt_name, (kx, ky, _conf) in mouse_data['keypoints'].items():
        kx, ky = int(kx), int(ky)
        kpt_color = color_kpt.get(kpt_name, _DEFAULT_KPT_COLOR)
        cv2.circle(frame, (kx, ky), 4, kpt_color, -1)
        cv2.putText(frame, kpt_name, (kx+5, ky),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, kpt_color, 1)


def overlay_annotations_on_video(input_video, annotations, color_box, color_kpt, output_video="output.mp4", discard=(False, []), first_frame_index=1):
    """
    Draw per-frame mouse boxes, ids and keypoints onto a video and save it.

    Args:
        input_video: path of the source video (read with cv2.VideoCapture).
        annotations: {frame_index: {mouse_id: {"bbox": {"x1", "y1", "x2", "y2"},
            "keypoints": {name: [x, y, vis]}}}}. Frame and mouse keys may be
            ints or their string form (e.g. after a JSON round trip).
        color_box: {mouse_id: OpenCV color tuple} for boxes and labels. Ids
            missing from it are drawn in _DEFAULT_BOX_COLOR instead of raising.
        color_kpt: {keypoint_name: OpenCV color tuple}. Missing names use
            _DEFAULT_KPT_COLOR.
        output_video: path of the written video (mp4v codec, source fps/size).
        discard: (enabled, ids). When enabled, mouse ids in `ids` are not drawn.
        first_frame_index: annotation key of the video's first frame (default 1).

    Raises:
        OSError: if the input video cannot be opened.
    """
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        cap.release()
        raise OSError(f"Could not open input video: {input_video}")

    width  = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps    = cap.get(cv2.CAP_PROP_FPS)

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # or 'XVID'/'avc1' etc.
    out    = cv2.VideoWriter(output_video, fourcc, fps, (width, height))

    discard_enabled = discard[0]
    discard_ids = discard[1] if discard_enabled else ()
    discard_strs = {str(d) for d in discard_ids}

    try:
        frame_index = first_frame_index
        while True:
            ret, frame = cap.read()
            if not ret:
                break  # no more frames in video

            frame_mice = _lookup_by_key(annotations, frame_index, {})
            for mouse_id, mouse_data in frame_mice.items():
                if discard_enabled and (mouse_id in discard_ids or str(mouse_id) in discard_strs):
                    continue
                _draw_mouse(frame, mouse_id, mouse_data, color_box, color_kpt)

            out.write(frame)
            frame_index += 1
    finally:
        # Release both handles even if drawing fails mid-video.
        cap.release()
        out.release()

    print("Finished writing annotated video:", output_video)

In [None]:
def load_metadata(source_dir, metadata_filename):
    """Read and return the JSON document stored at source_dir/metadata_filename."""
    with open(os.path.join(source_dir, metadata_filename), 'r') as handle:
        document = json.load(handle)
    return document

In [None]:
def save_metadata(output_dir, metadata_filename, metadata):
    """
    Write `metadata` as JSON (indent=4) to output_dir/metadata_filename,
    overwriting any existing file.

    JSON object keys are always strings, so integer dict keys come back as
    str when the file is loaded again.
    """
    target_path = os.path.join(output_dir, metadata_filename)
    with open(target_path, 'w') as handle:
        json.dump(metadata, handle, indent=4)

# Main

In [None]:
predicted_labels_dir = "/mnt/c/Users/karti/chest/CNR/projects/data/neurocig/vids/results/test1_noTrack/labels"
video_predcitionOn_path = "/mnt/c/Users/karti/chest/CNR/projects/data/neurocig/vids/processed/Gabbia2-D6-eCig(1)-pre.mp4"
img_w, img_h = get_video_resolution(video_predcitionOn_path)
if img_w is None or img_h is None:
    # Fail here rather than with an opaque TypeError while denormalizing labels.
    raise RuntimeError(f"No video stream found in {video_predcitionOn_path}")

# dict { frame_idx: [ { 'bbox':(x,y,w,h), 'keypoints':... }, ... ] }
detections = {}
mAnnotated_flag = False
visiblePercentage = 0.90
# sorted(): os.listdir order is arbitrary; this makes the load order reproducible.
for predicted_label in sorted(os.listdir(predicted_labels_dir)):
    if not predicted_label.endswith('.txt'):
        continue
    txt_path = os.path.join(predicted_labels_dir, predicted_label)

    # Assumes label files are named "<video stem>_<frame>.txt". Taking the
    # text after the LAST underscore keeps stems that contain "_" parseable.
    label_stem = os.path.splitext(predicted_label)[0]
    frame_index = int(label_stem.rsplit('_', 1)[-1])

    detection = yolo_txt_to_detection(txt_path, frame_index, img_w, img_h, mAnnotated_flag, visiblePercentage, ["nose", "earL", "earR", "tailB"])
    detections.update(detection)
        

In [None]:
# Run the tracker over every loaded frame: {frame_idx: {track_id: {...}}}.
tracked_detections = track_multi_mice(all_detections=detections)

In [None]:
# Per frame, compare how many tracks the tracker reports with how many
# detections were loaded. Iterating the dict itself avoids assuming frames
# are keyed contiguously 1..N, which raised KeyError on missing frames.
count_more = 0
count_less = 0
for frame_idx, frame_dets in detections.items():
    n_tracked = len(tracked_detections.get(frame_idx, {}))
    if n_tracked > len(frame_dets):
        count_more += 1
    elif n_tracked < len(frame_dets):
        count_less += 1

print(f"Number of time tracked detection are more {count_more} and number of time tracked detections are less than the original detection {count_less}")

In [None]:
# Attach a track id to every detection: {frame_idx: {track_id: {'bbox', 'keypoints'}}}.
final_detections = {}

for frame_index, mice in detections.items():
    id_mice = {}
    frame_tracks = tracked_detections[frame_index]

    for mouse in mice:
        for tracked_mice_id, tracked_mice in frame_tracks.items():
            # Identity, not equality. Track stores the detection's own
            # 'keypoints' object and the tracker reports that same object, so
            # `is` selects exactly the track that consumed this detection in
            # this frame. `==` could also match a stale (unmatched) track or
            # another mouse with equal keypoints, e.g. two empty dicts.
            if tracked_mice['keypoints'] is mouse['keypoints']:
                id_mice[tracked_mice_id] = {'bbox': mouse['bbox_xY'], 'keypoints': mouse['keypoints']}
                break

    final_detections[frame_index] = id_mice

In [None]:
# Number of frames in the id-resolved annotations (one entry per loaded frame).
print(len(final_detections))

In [None]:
# Count frames whose number of id-resolved mice differs from the number of
# detections. Iterating the dict avoids assuming contiguous 1..N frame keys.
count = 0
for frame_idx, frame_dets in detections.items():
    if len(frame_dets) != len(final_detections.get(frame_idx, {})):
        count += 1

print(count)

In [None]:
# Directory receiving tracked_annotations.json and tracked_video.mp4.
output_path = "/mnt/c/Users/karti/chest/CNR/projects/data/neurocig/vids/results"

In [None]:
save_metadata(output_path, 'tracked_annotations.json', final_detections)

# JSON object keys are always strings, so reloading turns the integer frame
# indices and track ids into "1", "2", ... Convert them back to int so the
# int-keyed lookups downstream (frame counter, color_box) keep working.
final_detections = {
    int(frame_key): {int(track_key): mouse for track_key, mouse in frame_mice.items()}
    for frame_key, frame_mice in load_metadata(output_path, 'tracked_annotations.json').items()
}

In [None]:
# Usage:

FinalVideo_path = os.path.join(output_path, 'tracked_video.mp4')

color_box = {
    1 : (0, 255, 255),
    2 : (0, 255, 128),
    3 : (153, 51, 155),
    4 : (255, 255, 0),
    5 : (255, 0, 255)
}

# The tracker hands out a new id for every unmatched detection, so ids are not
# limited to the five above. Give every other id present in the annotations a
# color from the same palette (cycled by id) instead of failing with KeyError.
_palette = list(color_box.values())
for _frame_mice in final_detections.values():
    for _mouse_id in _frame_mice:
        if _mouse_id in color_box:
            continue
        _as_int = int(_mouse_id)
        color_box[_mouse_id] = color_box.get(_as_int, _palette[(_as_int - 1) % len(_palette)])

color_kpt = {
    'nose' : (153, 204, 255),
    'earL' : (255, 182, 78),
    'earR' : (255, 102, 102),
    'tailB' : (255, 153, 204)
}

discard = (False, [])

overlay_annotations_on_video(video_predcitionOn_path, final_detections, color_box, color_kpt, FinalVideo_path, discard)