# In [None]:
# 1-2-3-4-1: chain keypoints through images 1 -> 2 -> 3 -> 4 and back to image 1

# In [2]:
# import libraries
import cv2
import numpy as np
import glob
import os
import matplotlib.pyplot as plt
import time

# Function to detect keypoints and compute descriptors
def detect_and_compute(detector, descriptor, img):
    """Find keypoints with *detector*, then describe them with *descriptor*.

    Detection and description are separate calls, so different algorithms
    can be paired; passing the same object (e.g. one SIFT instance) for both
    also works.

    Returns:
        ``(keypoints, descriptors)`` as returned by ``descriptor.compute``.
        The keypoint list comes from ``compute`` rather than ``detect``, so
        it may differ from the detected set (OpenCV's ``compute`` can drop
        keypoints it cannot describe).
    """
    found = detector.detect(img, None)
    described_kps, desc_matrix = descriptor.compute(img, found)
    return described_kps, desc_matrix

# Function to read images from a directory
def read_images(query_image_path, train_images_dir):
    """Load the query image and every decodable image in a directory, in grayscale.

    Args:
        query_image_path: Path of the query image.
        train_images_dir: Directory whose files are loaded as train images.

    Returns:
        ``(query_img, train_images)``. ``query_img`` is ``None`` if
        ``cv2.imread`` could not read the query file. ``train_images`` is
        ordered by file path (lexicographically), so code that depends on
        image order (chaining image 1 -> 2 -> 3 -> ...) is reproducible.
        Files that ``cv2.imread`` cannot decode are skipped.
    """
    query_img = cv2.imread(query_image_path, cv2.IMREAD_GRAYSCALE)
    # glob returns paths in arbitrary, filesystem-dependent order; sort them
    # so the sequence of train images is deterministic across machines.
    train_paths = sorted(glob.glob(os.path.join(train_images_dir, "*")))
    train_images = []
    for img_path in train_paths:
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None (instead of raising) for files it cannot
        # decode, e.g. sub-directories or non-image metadata files.
        if img is not None:
            train_images.append(img)
    return query_img, train_images

# Function to match descriptors and return the matches
def match_descriptors(descriptor_extractor, query_descriptors, train_descriptors, ratio=0.9):
    """Match two descriptor sets and keep matches that pass Lowe's ratio test.

    Args:
        descriptor_extractor: A matcher exposing ``knnMatch`` (e.g. a
            ``cv2.BFMatcher``); despite the name it is only used for matching.
        query_descriptors: Descriptors of the query image, or ``None`` when
            no descriptors were computed.
        train_descriptors: Descriptors of the train image, or ``None``.
        ratio: A best match is kept only if its distance is strictly less
            than ``ratio`` times the second-best match's distance.
            Defaults to 0.9, the value previously hard-coded here.

    Returns:
        A list of best-neighbour matches (at most one per query descriptor)
        that passed the ratio test; empty if either side has no descriptors.
    """
    # OpenCV's compute may yield None descriptors for an image without
    # keypoints; knnMatch would fail on those, and there is nothing to match.
    if query_descriptors is None or train_descriptors is None:
        return []
    if len(query_descriptors) == 0 or len(train_descriptors) == 0:
        return []

    matches = descriptor_extractor.knnMatch(query_descriptors, train_descriptors, k=2)
    good_matches = []
    for candidates in matches:
        # knnMatch can return fewer than k neighbours (e.g. when the train set
        # holds a single descriptor); the ratio test needs two, so skip those.
        if len(candidates) < 2:
            continue
        best, second = candidates[0], candidates[1]
        if best.distance < ratio * second.distance:
            good_matches.append(best)
    return good_matches

def _build_chains(matcher, keypoints_descriptors):
    """Link ratio-test matches across consecutive images into keypoint chains.

    ``keypoints_descriptors[i]`` is ``(keypoints, descriptors)`` for image i.
    Returns a list of chains where ``chain[i]`` is a keypoint index into
    image ``i``. A chain is extended only while image i-1's last keypoint
    has a match in image i; otherwise it is dropped.
    """
    matches = match_descriptors(matcher, keypoints_descriptors[0][1], keypoints_descriptors[1][1])
    chains = [[m.queryIdx, m.trainIdx] for m in matches]

    for i in range(2, len(keypoints_descriptors)):
        matches = match_descriptors(matcher, keypoints_descriptors[i - 1][1], keypoints_descriptors[i][1])
        # Index matches by query keypoint once, instead of scanning every
        # match for every chain (chains x matches).
        successors = {}
        for m in matches:
            successors.setdefault(m.queryIdx, []).append(m.trainIdx)
        chains = [chain + [t] for chain in chains for t in successors.get(chain[-1], ())]
    return chains


def _close_cycles(matcher, keypoints_descriptors, chains):
    """Keep chains whose last keypoint matches back to their first keypoint.

    This closes the loop 1 -> 2 -> ... -> N -> 1: the last image's
    descriptors are matched against the first image's, and a chain is kept
    only when a ratio-test match from ``chain[-1]`` lands on ``chain[0]``.
    """
    back_matches = match_descriptors(matcher, keypoints_descriptors[-1][1], keypoints_descriptors[0][1])
    back = {}
    for m in back_matches:
        back.setdefault(m.queryIdx, set()).add(m.trainIdx)
    return [chain for chain in chains if chain[0] in back.get(chain[-1], ())]


def _draw_chains(images, keypoints_descriptors, chains, colors):
    """Place grayscale images side by side and draw each chain as line segments.

    Every image gets a slot as wide as the widest image. The segment from
    image i to image i+1 is drawn in ``colors[i % len(colors)]``.
    Returns the BGR canvas.
    """
    # Height and width must be maximised independently: max() over (h, w)
    # tuples compares lexicographically and can pick a narrower image.
    h = max(img.shape[0] for img in images)
    w = max(img.shape[1] for img in images)
    output_img = np.zeros((h, len(images) * w, 3), dtype="uint8")
    for i, img in enumerate(images):
        ih, iw = img.shape[:2]
        # Paste at the image's own size so smaller images fit their slot.
        output_img[:ih, i * w:i * w + iw, :] = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    for chain in chains:
        for i in range(len(chain) - 1):
            x1, y1 = map(int, keypoints_descriptors[i][0][chain[i]].pt)
            x2, y2 = map(int, keypoints_descriptors[i + 1][0][chain[i + 1]].pt)
            # Shift x by the slot offset of the image each keypoint lives in.
            cv2.line(output_img, (x1 + i * w, y1), (x2 + (i + 1) * w, y2),
                     colors[i % len(colors)], 2)
    return output_img


def main(query_image_path, train_images_dir):
    """Track SIFT keypoints through the train images in a closed cycle and save a plot.

    Loads images via ``read_images`` (the query image is read but currently
    unused). Links ratio-test matches image 1 -> 2 -> ... -> N and keeps only
    chains whose image-N keypoint matches back to their image-1 keypoint.
    Draws the surviving chains across the side-by-side images and writes
    ``result_<timestamp>.png`` to the current directory.

    Returns:
        The path of the written result image.

    Raises:
        ValueError: If fewer than two train images were loaded.
        OSError: If the result image could not be written.
    """
    sift = cv2.SIFT_create()

    _, images = read_images(query_image_path, train_images_dir)
    if len(images) < 2:
        raise ValueError(
            f"need at least two images in {train_images_dir!r}, found {len(images)}"
        )

    keypoints_descriptors = [detect_and_compute(sift, sift, img) for img in images]

    # Reuse one matcher instead of constructing one per image pair.
    matcher = cv2.BFMatcher(cv2.NORM_L2)
    chains = _build_chains(matcher, keypoints_descriptors)
    chains = _close_cycles(matcher, keypoints_descriptors, chains)

    # OpenCV images are BGR, so these are red, green, blue and yellow.
    colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0), (0, 255, 255)]
    output_img = _draw_chains(images, keypoints_descriptors, chains, colors)

    # Save result image with a unique timestamp suffix
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    out_path = f'result_{timestamp}.png'
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(out_path, output_img):
        raise OSError(f"could not write {out_path}")
    return out_path


# Usage: run only when executed as a script (or notebook cell), not on import.
if __name__ == "__main__":
    main('query/image1.png', 'all/')
