In [None]:
#-- Matching the Same Object Across Tracker ID Changes --

In [1]:
from IPython import display

In [2]:
#-- Install ultralytics ------------------------------------------------------------------------------------------
!pip install ultralytics

display.clear_output()

import ultralytics
ultralytics.checks()
#-----------------------------------------------------------------------------------------------------------------

Ultralytics 8.3.36 🚀 Python-3.10.14 torch-2.4.0 CUDA:0 (Tesla T4, 15095MiB)
Setup complete ✅ (4 CPUs, 31.4 GB RAM, 5933.9/8062.4 GB disk)


In [3]:
#-- Imports ----------------------------------------------------------------------------------------------------
from ultralytics import YOLO
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.models import Model
from collections import defaultdict
import cv2
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import os
from datetime import datetime, timedelta
import shutil
#-----------------------------------------------------------------------------------------------------------------

In [5]:
#-- Initialize --------------------------------------------------------------------------------------------------
out_dir = '/kaggle/working/'
detection_weights_file = '/kaggle/input/yolo11-11frozen-13/model_11_frozen_epoch_60/train/weights/best.pt'

drone_files = ['/kaggle/input/drone-dataset-p1/v_5.mp4',
              '/kaggle/input/drone-dataset-p2/v_8.mp4',
              '/kaggle/input/drone-detection-test-videos-1/drone_video (1).mp4',
              # '/kaggle/input/novin-data/Novin_Dataset/f2.part2.mp4',
              '/kaggle/input/sample-videos-detecting-and-matching-objs-1/sample_video_drone (5).mp4',
              '/kaggle/input/video-drone-bird-1/Untitled-13.mp4']

#-- annotated videos go here; keep the trailing '/', output paths are built by string concatenation --
results_dir = out_dir + 'results/'
os.makedirs(results_dir, exist_ok=True)  #-- no-op when the directory already exists

#-- Matching / tracking thresholds --
AREA_THRESHOLD = 5          #-- max ratio (larger / smaller) between two box areas for them to match
DISTANCE_THRESHOLD = 50     #-- box-centre distance in frame pixels; candidates this close are preferred
SIMILARITY_THRESHOLD = 0.6  #-- minimum cosine similarity between the ResNet50 embeddings of two crops
CROP_PADDING = 10           #-- extra pixels kept on every side of a box when cropping it
TIME_THRESHOLD = 60         #-- max seconds since a candidate was last seen for it to still be matched
NUM_TRACK_THRESHOLD = 30    #-- number of most recent centre points kept (and drawn) per object trail
#-----------------------------------------------------------------------------------------------------------------

In [6]:
#-- Set Detection Model ------------------------------------------------------------------------------------------
#-- YOLO detector loaded from the best.pt checkpoint path set in the Initialize cell --
model = YOLO(model=detection_weights_file)
#-----------------------------------------------------------------------------------------------------------------

In [None]:
#-- Set Similarity Measure Model ---------------------------------------------------------------------------------
#-- ImageNet-pretrained ResNet50; an internal layer's activations serve as the image embedding --
similarity_base_model = ResNet50(weights='imagenet')

#-- Embedding extractor: same input as the base model, output taken from its second-to-last layer --
#-- NOTE(review): with Keras' default include_top=True this is presumably the global-average-pool layer — confirm --
_penultimate_layer = similarity_base_model.layers[-2]
similarity_model = Model(
    inputs=similarity_base_model.input,
    outputs=_penultimate_layer.output,
)
#-----------------------------------------------------------------------------------------------------------------

In [None]:
#-- Function to Preprocess Image for Similarity Measure ---------------------------------------------------------
def preprocess_image(image, target_size=(224, 224), bgr_input=True):
    """Turn one image crop into a single-image batch ready for ResNet50.

    Args:
        image: colour image array. Crops in this notebook are cut from
            cv2.VideoCapture frames, i.e. they are in OpenCV BGR channel order.
        target_size: (width, height) passed to cv2.resize (OpenCV order).
        bgr_input: if True (default), convert BGR -> RGB first. Keras'
            resnet50.preprocess_input expects RGB and performs its own
            RGB -> BGR flip, so feeding OpenCV pixels unconverted leaves the
            channels reversed relative to what the ImageNet weights expect.

    Returns:
        float32 array of shape (1, target_size[1], target_size[0], channels),
        normalised by preprocess_input.

    Raises:
        ValueError: if `image` is None or empty (e.g. a zero-area crop);
            cv2.resize would otherwise fail with an opaque cv2.error.
    """
    if image is None or image.size == 0:
        raise ValueError('preprocess_image: got an empty image (zero-area crop?)')

    if bgr_input:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, target_size)
    batch = np.expand_dims(image.astype(np.float32), axis=0)  #-- add batch dimension
    return preprocess_input(batch)  #-- RGB->BGR flip + ImageNet mean subtraction (Keras ResNet50 convention)
#-----------------------------------------------------------------------------------------------------------------

In [None]:
#-- Function to Calculate Similarity -----------------------------------------------------------------------------
def compare_similarity_images(image1, image2):
    """Cosine similarity between the ResNet50 embeddings of two image crops.

    Both crops are prepared with preprocess_image and embedded together in a
    single forward pass of similarity_model.

    Returns:
        Scalar cosine similarity (higher means more alike; at most 1).
    """
    #-- one batch of two images instead of two separate model invocations --
    batch = np.concatenate([preprocess_image(image1), preprocess_image(image2)], axis=0)

    #-- call the model directly: Keras recommends __call__ over predict() for small inputs
    #-- processed inside a loop, and it avoids printing a progress bar on every comparison
    embeddings = np.asarray(similarity_model(batch, training=False))

    #-- cosine_similarity expects 2-D inputs, so compare (1, d) row slices --
    similarity_score = cosine_similarity(embeddings[0:1], embeddings[1:2])[0][0]
    return similarity_score
#-----------------------------------------------------------------------------------------------------------------

In [None]:
#-- Function to Match Detected Objects ---------------------------------------------------------------------------
def match_object(track_id, track_box, track_image, track_time, last_tracked_objects, debug=True):
    """Try to re-identify a newly tracked object as one of the objects seen before.

    A candidate from `last_tracked_objects` survives only if ALL of these hold:
      * it was last seen at most TIME_THRESHOLD seconds before `track_time`;
      * the larger/smaller ratio of the two box areas is at most AREA_THRESHOLD;
      * the similarity of the two crops (compare_similarity_images) is at least
        SIMILARITY_THRESHOLD.
    Among survivors, the one whose centre is nearest (within DISTANCE_THRESHOLD)
    wins; if none is that close, the MOST similar survivor wins.

    Args:
        track_id: tracker id of the new object (used for debug output only).
        track_box: (center_x, center_y, w, h) of the new object.
        track_image: crop of the new object (OpenCV array).
        track_time: when the new object was seen; subtracting a stored time
            from it must give a datetime.timedelta (datetimes or timedeltas both work).
        last_tracked_objects: {obj_id: (box, crop, last_seen_time)} candidates.
        debug: if True (default), print every check and display the crops.

    Returns:
        The matched obj_id, or None if no candidate passes every check.
    """
    def _show(image, title):
        #-- crops come from OpenCV frames (BGR); reverse channels so matplotlib shows true colours --
        plt.imshow(image[..., ::-1] if image.ndim == 3 else image)
        plt.title(title)
        plt.axis('off')
        plt.show()

    if debug:
        _show(track_image, f'track_object - id:{track_id}')

    track_center_x, track_center_y, track_w, track_h = track_box
    a_track = track_w * track_h

    distance_match = {}    #-- obj_id -> centre distance, for candidates within DISTANCE_THRESHOLD
    similarity_match = {}  #-- obj_id -> similarity score, for candidates farther away

    for obj_id, (obj_box, obj_img, obj_time) in last_tracked_objects.items():

        if debug:
            print(f'####################### {obj_id} #####################')
            _show(obj_img, f'object- id:{obj_id}')

        time_difference = abs(track_time - obj_time)
        if debug:
            print(f'-------- time_difference: {time_difference} -------------')
        if time_difference > timedelta(seconds=TIME_THRESHOLD):
            continue

        #-- area check before similarity: it is cheap, while similarity needs a ResNet forward pass.
        #-- All checks must pass, so their order does not change which candidates survive.
        x_center, y_center, w, h = obj_box
        a_obj = w * h
        if a_track <= 0 or a_obj <= 0:
            #-- degenerate box: the area ratio is undefined, so it cannot be a reliable match --
            if debug:
                print('-------- zero-area box, skipped -------------')
            continue
        a_ratio = max(a_track, a_obj) / min(a_track, a_obj)
        if debug:
            print(f'-------- a_ratio: {a_ratio} -------------')
        if a_ratio > AREA_THRESHOLD:
            continue

        similarity_score = compare_similarity_images(track_image, obj_img)
        if debug:
            print(f'-------- similarity_score: {similarity_score} -------------')
        if similarity_score < SIMILARITY_THRESHOLD:
            continue

        distance = np.hypot(track_center_x - x_center, track_center_y - y_center)
        if debug:
            print(f'-------- distance: {distance} -------------')
        if distance <= DISTANCE_THRESHOLD:
            distance_match[obj_id] = distance
        else:
            similarity_match[obj_id] = similarity_score

    matched_id = None
    if distance_match:
        matched_id = min(distance_match, key=distance_match.get)      #-- nearest candidate
    elif similarity_match:
        matched_id = max(similarity_match, key=similarity_match.get)  #-- MOST similar candidate

    if debug:
        print(f'-------- matched_id: {matched_id} -------------')
    return matched_id


#-----------------------------------------------------------------------------------------------------------------

In [None]:
def crop_object(frame, box, padding=CROP_PADDING):
    """Return a padded copy of the region of `frame` covered by `box`.

    `box` is (center_x, center_y, w, h) in frame pixels. The region is grown by
    `padding` pixels on every side and clipped to the frame bounds, so crops
    near the border can be smaller. Half-sizes use floor division (`// 2`).
    """
    img_h, img_w = frame.shape[:2]
    cx, cy, box_w, box_h = box
    half_w, half_h = box_w // 2, box_h // 2

    left = int(max(cx - half_w - padding, 0))
    top = int(max(cy - half_h - padding, 0))
    right = int(min(cx + half_w + padding, img_w))
    bottom = int(min(cy + half_h + padding, img_h))

    #-- .copy() so the stored crop does not alias the frame, which is drawn on afterwards --
    return frame[top:bottom, left:right].copy()


In [None]:
#-- Run ----------------------------------------------------------------------------------------------------------
DEFAULT_FPS = 30.0  #-- NOTE(review): assumed frame rate when OpenCV cannot read one from the file — adjust if known
MAX_FRAMES = None   #-- set e.g. 900 for a quick partial run; None processes every frame


def _reset_tracker(yolo_model):
    """Clear tracker state left on `yolo_model` by a previously processed video.

    model.track(..., persist=True) keeps the tracker between calls so IDs stay
    stable frame to frame; without a reset, tracks from the previous video would
    carry over into the next one. No-op before the first track() call.
    """
    predictor = getattr(yolo_model, 'predictor', None)
    for tracker in (getattr(predictor, 'trackers', None) or []):
        tracker.reset()


def _draw_track(frame, box, track_id, track):
    """Draw one object's trail, bounding box and ID label onto `frame` in place.

    `box` is (center_x, center_y, w, h) in frame pixels; `track` is the list of
    recent (x, y) centre points for this ID. Colours are in OpenCV BGR order.
    """
    x, y, w, h = box

    #-- trail of recent centre points (red) --
    points = np.array(track, dtype=np.int32).reshape((-1, 1, 2))
    cv2.polylines(frame, [points], isClosed=False, color=(0, 0, 255), thickness=4)

    #-- bounding box (blue) --
    top_left = (int(x - w / 2), int(y - h / 2))
    bottom_right = (int(x + w / 2), int(y + h / 2))
    cv2.rectangle(frame, top_left, bottom_right, (255, 0, 0), 2)

    #-- ID label (green), just above the box --
    text_position = (top_left[0], top_left[1] - 10)
    cv2.putText(frame, f'ID: {track_id}', text_position, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)


for video_file in drone_files:

    #-- input / output names --
    video_name = os.path.basename(video_file)
    output_path = results_dir + 'out_' + video_name

    print(f'=== Processing {video_name} ================================')

    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        print(f'!!! could not open {video_file}, skipping')
        continue

    #-- video properties; keep fps fractional (e.g. 29.97) so the output plays at the input's speed --
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not (fps > 0):  #-- also catches NaN
        fps = DEFAULT_FPS

    #-- set video writer --
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))

    #-- per-video state --
    _reset_tracker(model)
    track_history = defaultdict(list)  #-- resolved id -> recent centre points
    last_tracked_objects = {}          #-- resolved id -> (box, crop, time last seen)
    mapped_objects = {}                #-- raw tracker id -> re-identified (resolved) id
    frame_number = 0

    try:
        while True:
            success, frame = cap.read()
            if not success:
                break  #-- end of video (or read error)
            frame_number += 1
            if MAX_FRAMES is not None and frame_number > MAX_FRAMES:
                break

            print(f'\nframe number = {frame_number} ==================================')
            print('track_history:', track_history.keys())
            print('last_tracked_objects:', last_tracked_objects.keys())

            #-- timestamp in VIDEO time, so TIME_THRESHOLD does not depend on how fast frames are processed --
            detection_time = timedelta(seconds=frame_number / fps)

            #-- detect and track objects --
            results = model.track(frame,
                                  tracker='bytetrack.yaml',
                                  persist=True,
                                  show=False)
            boxes_result = results[0].boxes

            if (boxes_result is not None and boxes_result.xywh is not None
                    and boxes_result.id is not None):
                boxes = boxes_result.xywh.cpu()
                raw_ids = boxes_result.id.int().cpu().tolist()

                #-- IDs already visible in this frame cannot also be a newly appearing object --
                occupied_ids = set()
                for raw_id in raw_ids:
                    if raw_id in last_tracked_objects:
                        occupied_ids.add(raw_id)
                    elif raw_id in mapped_objects:
                        occupied_ids.add(mapped_objects[raw_id])

                for box, raw_id in zip(boxes, raw_ids):
                    cropped_object = crop_object(frame, box)
                    track_id = raw_id

                    #-- unknown tracker id (and something to compare with): try to re-identify it --
                    if last_tracked_objects and track_id not in last_tracked_objects:
                        matched_id = mapped_objects.get(track_id)
                        if matched_id is None:
                            candidates = {obj_id: obj for obj_id, obj in last_tracked_objects.items()
                                          if obj_id not in occupied_ids}
                            if candidates:
                                matched_id = match_object(track_id, box, cropped_object,
                                                          detection_time, candidates)
                        if matched_id is not None:
                            mapped_objects[track_id] = matched_id
                            track_id = matched_id

                    occupied_ids.add(track_id)
                    last_tracked_objects[track_id] = (box, cropped_object, detection_time)

                    #-- extend this object's trail, keeping only the newest NUM_TRACK_THRESHOLD points --
                    track = track_history[track_id]
                    track.append((float(box[0]), float(box[1])))
                    if len(track) > NUM_TRACK_THRESHOLD:
                        track.pop(0)

                    _draw_track(frame, box, track_id, track)

            #-- annotations were drawn onto `frame` in place (unchanged if nothing was tracked) --
            out.write(frame)
    finally:
        #-- release the capture and writer even if processing was interrupted --
        cap.release()
        out.release()
    

In [None]:
# zip_results = "results"
# shutil.make_archive(zip_results, 'zip', results_dir)