In [64]:
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import cv2
import os

In [49]:
FILENAME = "test_corner.mp4"
MAX_FRAMES = 100
INTERVAL = 10
MAX_MATCH_DISTANCE = 40

In [62]:
def extract_keypoints(video):
    # Create a VideoCapture object to read the video file
    cap = cv2.VideoCapture(video)
    # Create an ORB object
    orb = cv2.ORB_create()
    # Extract all keypoints and descriptors by frame
    frame_kpt, frame_des = [], []
    video_frames = []
    k = 1
    # Loop through the video frames
    with tqdm(total=MAX_FRAMES, colour="blue") as pbar:
        while cap.isOpened() and k <= MAX_FRAMES:
            # Read a frame from the video
            ret, frame = cap.read()
            # Check if the frame was successfully read
            if not ret:
                continue
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            kpt, des = orb.detectAndCompute(gray, None)
            if des is None:
                print("No keypoints/descriptors in frame ", k)
                continue
            pbar.update(1)
            frame_kpt.append(kpt)
            frame_des.append(des)
            video_frames.append(frame)
            k += 1
            # Wait for Esc key to stop
            if cv2.waitKey(1) == 27:
                # De-allocate any associated memory usage
                cv2.destroyAllWindows()
                cap.release()
                break
        cap.release()
        return frame_kpt, frame_des, video_frames


def compute_all_matches(frame_des, frame_kpt):
    # Create a Brute Force Matcher object
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
    # Compute all matches
    all_matches = []
    matched_keypoints = []
    matched_descriptors = []
    for i in range(len(frame_des) - 1):
        matches = bf.match(frame_des[i], frame_des[i + 1])
        matches = [match for match in matches if match.distance < MAX_MATCH_DISTANCE]
        all_matches.append(matches)
        # Get matched keypoints
        matched_keypoints.append([[frame_kpt[i][match.queryIdx] for match in matches]])
        matched_descriptors.append(
            [[frame_des[i][match.queryIdx] for match in matches]]
        )
    if i == len(frame_des) - 2:
        matched_keypoints.append(
            [[frame_kpt[i + 1][match.trainIdx] for match in matches]]
        )
        matched_descriptors.append(
            [[frame_des[i + 1][match.trainIdx] for match in matches]]
        )
    return all_matches


frame_kpt, frame_des, video_frames = extract_keypoints(FILENAME)
all_matches = []
for i in tqdm(range(0, len(frame_kpt), 10), colour="blue"):
    all_matches.append(compute_all_matches(frame_des[0 + i : 10 + i], frame_kpt[0 + i : 10 + i]))

100%|[34m██████████[0m| 100/100 [00:01<00:00, 64.51it/s]
100%|[34m██████████[0m| 10/10 [00:00<00:00, 149.16it/s]


In [65]:
def compare_matches(all_matches: np.ndarray) -> np.ndarray:
    """
    description: Compare the matches between each pair of intervals.
    param       {np.ndarray} all_matches: DataFrame containing the matches between each pair of intervals.
    return      {np.ndarray}: A list containing the matches between each pair of intervals.
    """
    # List to store the continuous matches and the terminal matches
    continues_keypoints = []
    terminated_keypoints = []
    # Extract pairs of matching keypoints' indices from the matches
    with tqdm(total=len(all_matches)) as pbar:
        for i, matches in enumerate(all_matches):
            pbar.update(1)
            if not matches:
                continue
            # Extract the indices of the matching keypoints
            match_pair_1, match_pair_2 = zip(*matches)
            # Extract the indices of the matching keypoints from the last interval
            last_matches = [x[1] for x in continues_keypoints]
            for j in range(10):
                if j in match_pair_1:
                    if j in last_matches:
                        index = last_matches.index(j)
                        continues_keypoints[index][0].append(match_pair_2[
                            match_pair_1.index(j)
                        ])
                        continues_keypoints[index][2] = i + 2
                    else:
                        continues_keypoints.append(
                            [[j, match_pair_2[match_pair_1.index(j)]], i + 1, i + 2]
                        )
                else:
                    if j in last_matches:
                        terminated_keypoints.append(
                            continues_keypoints[last_matches.index(j)]
                        )

            # Remove the keypoints in consecutive matches that has terminated
            continues_keypoints = [
                x for x in continues_keypoints if x not in terminated_keypoints
            ]
    terminated_keypoints.extend(continues_keypoints)
    return terminated_keypoints

debug_data = [
    [(1, 3), (2, 4), (3, 5), (5, 9)],
    [(3, 4), (4, 5), (8, 9)],
    [(1, 2), (2, 3), (4, 5), (5, 6)],
    [(2, 4), (4, 5), (5, 6), (6, 7)],
    [(1, 1), (2, 3), (3, 5), (4, 6), (6, 9)],
]
print(compare_matches(debug_data))

100%|██████████| 5/5 [00:00<00:00, 65741.44it/s]

[[[1, 3], 1, 2], [[2, 4, 2], 1, 4], [[4, 5], 3, 4], [[3, 5, 1], 1, 6], [[5, 9], 1, 2], [[3, 4, 3, 4, 3], 2, 6], [[4, 5], 2, 3], [[8, 9], 2, 3], [[5, 6, 5], 3, 6], [[4, 5, 6], 4, 6], [[5, 6], 4, 5], [[6, 7], 4, 5], [[6, 9], 5, 6]]





In [69]:
def compare_matches(all_matches: np.ndarray) -> np.ndarray:
    """
    description: Compare the matches between each pair of intervals.
    param       {np.ndarray} all_matches: DataFrame containing the matches between each pair of intervals.
    return      {np.ndarray}: A list containing the matches between each pair of intervals.
    """
    # List to store the continuous matches and the terminal matches
    continues_keypoints = []
    terminated_keypoints = []
    # Extract pairs of matching keypoints' indices from the matches
    with tqdm(total=len(all_matches)) as pbar:
        for i, matches in enumerate(all_matches):
            pbar.update(1)
            if not matches:
                continue
            # Extract the indices of the matching keypoints
            match_pair_1, match_pair_2 = zip(*matches)
            # Extract the indices of the matching keypoints from the last interval
            last_matches = [x[1] for x in continues_keypoints]
            for j in range(10):
                if j in match_pair_1:
                    if j in last_matches:
                        index = last_matches.index(j)
                        continues_keypoints[index][1] = match_pair_2[
                            match_pair_1.index(j)
                        ]
                        continues_keypoints[index][3] = i + 2
                    else:
                        continues_keypoints.append(
                            [j, match_pair_2[match_pair_1.index(j)], i + 1, i + 2]
                        )
                else:
                    if j in last_matches:
                        terminated_keypoints.append(
                            continues_keypoints[last_matches.index(j)]
                        )
                    else:
                        terminated_keypoints.append([j, j, 1 + i, 1 + i])

            # Remove the keypoints in consecutive matches that has terminated
            continues_keypoints = [
                x for x in continues_keypoints if x not in terminated_keypoints
            ]
    terminated_keypoints.extend(continues_keypoints)
    return terminated_keypoints

debug_data = [
    [(1, 3), (2, 4), (3, 5), (5, 9)],
    [(3, 4), (4, 5), (8, 9)],
    [(1, 2), (2, 3), (4, 5), (5, 6)],
    [(2, 4), (4, 5), (5, 6), (6, 7)],
    [(1, 1), (2, 3), (3, 5), (4, 6), (6, 9)],
]
terminated_keypoints = compare_matches(debug_data)
for tk in terminated_keypoints:
    if tk[3] - tk[2] > 0:
        print(tk)


100%|██████████| 5/5 [00:00<00:00, 64133.09it/s]

[3, 5, 1, 2]
[5, 9, 1, 2]
[8, 9, 2, 3]
[2, 3, 3, 4]
[4, 5, 4, 5]
[2, 7, 1, 5]
[1, 9, 1, 6]
[1, 6, 3, 6]
[1, 1, 5, 6]
[2, 3, 5, 6]
[3, 5, 5, 6]





In [44]:
def plot_keypoints_traces(frame, kpts):
    plt.figure(figsize=(10, 10))
    plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    for kpt in kpts:
        plt.plot(kpt[0], kpt[1], 'ro', markersize=2)  # red dot for each keypoint
    plt.title("Keypoints' traces on the first frame of 100th interval")
    plt.show()

In [45]:
import time

for i in tqdm(range(10), colour='blue'):
    time.sleep(0.1)

100%|[34m██████████[0m| 10/10 [00:01<00:00,  9.54it/s]
