In [205]:
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt

import re
import os
from collections import defaultdict

In [206]:
def preprocess(img):
    img = cv.GaussianBlur(img, (3, 3), 0)
    img = cv.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(img)
    return img

In [207]:
def keypoints(I1, I2):
    sift = cv.SIFT_create()  # or cv.AKAZE_create() / cv.ORB_create() for speed
    kp1, des1 = sift.detectAndCompute(I1, None)
    kp2, des2 = sift.detectAndCompute(I2, None)
    return (kp1, des1), (kp2, des2)

In [208]:
def match(des1, des2):
    bf = cv.BFMatcher()
    matches = bf.knnMatch(des1, des2, k=2)

    good = []
    for m, n in matches:
        if m.distance < 0.75 * n.distance:  # Lowe's ratio test
            good.append(m)

    # print(f"Found {len(good)} good matches out of {len(matches)} total")
    return good

In [209]:
def ransac(kp1, kp2, good):
        pts1 = np.float32([kp1[m.queryIdx].pt for m in good])
        pts2 = np.float32([kp2[m.trainIdx].pt for m in good])

        M, inliers = cv.estimateAffinePartial2D(
            pts2, pts1,
            method=cv.RANSAC,
            ransacReprojThreshold=3.0,
            maxIters=2000,
            confidence=0.99
        )

        # print("Affine matrix:\n", M)
        return M, inliers

In [210]:
def display(img1_full, img1_gray, img2_full, img2_gray, show=False):

    blue_tinted_image = np.zeros_like(img1_full)
    blue_tinted_image[:, :, 0] = img1_gray
    red_tinted_image = np.zeros_like(img2_full)
    red_tinted_image[:, :, 2] = img2_gray

    overlay = cv.addWeighted(blue_tinted_image, 0.5, red_tinted_image, 0.5, 0)

    if show:
        plt.figure(figsize=(24,12))
        plt.subplot(1,1,1); plt.imshow(overlay); plt.title("Overlay after alignment")
        plt.show()

    return overlay

In [211]:
def register(img1_full, img2_full):

    img1_gray = cv.cvtColor(img1_full, cv.COLOR_BGR2GRAY)
    img2_gray = cv.cvtColor(img2_full, cv.COLOR_BGR2GRAY)

    I1, I2 = preprocess(img1_gray), preprocess(img2_gray)

    # --- 1. Detect and describe keypoints ---
    (kp1, des1), (kp2, des2) = keypoints(I1, I2)

    # --- 2. Match descriptors with ratio test ---
    if kp1 is not None and kp2 is not None and des1 is not None and des2 is not None:
        good = match(des1, des2)

        # --- 3. Estimate affine transform using RANSAC ---
        if len(good) >= 3:  # need at least 3 points for affine
            M, inliers = ransac(kp1, kp2, good)

            # If you need to enforce only translation, you can extract the translation components
            # from the estimated matrix and create a new purely translational matrix.
            # The translation components are in the last column of the affine matrix.
            # tx = M[0, 2]
            # ty = M[1, 2]

            # Create a purely translational matrix
            # translation_matrix = np.array([[1, 0, tx],
            #                                [0, 1, ty]], dtype=np.float32)

            # print("\nPurely Translational Matrix:")
            # print(translation_matrix)

            # --- 4. Warp moving image ---
            aligned_gray = cv.warpAffine(
                img2_gray, M, (img2_gray.shape[1], img2_gray.shape[0]),
                flags=cv.INTER_LINEAR
            )
            aligned_full = cv.warpAffine(
                img2_full, M, (img2_full.shape[1], img2_full.shape[0]),
                flags=cv.INTER_LINEAR
            )
            # aligned_gray = cv.warpAffine(img2_gray, translation_matrix, (img2_gray.shape[1], img2_gray.shape[0]))
            # aligned_full = cv.warpAffine(img2_full, translation_matrix, (img2_full.shape[1], img2_full.shape[0]))

            # --- 5. Quick visual check ---
            overlay = display(img1_full, img1_gray, aligned_full, aligned_gray)
            return aligned_full, overlay

        else:
            print("Not enough good matches for reliable registration.")
    else:
        print("Not enough descriptors for reliable registration.")

In [212]:
def pipeline(fname1, fname2):
    img1 = cv.imread(fname1)    # reference (earlier)
    img2 = cv.imread(fname2)      # moving (later)

    result = register(img1, img2)

    if result is not None:
        aligned_full, overlay = result
        filename_without_extension, file_extension = os.path.splitext(fname2)
        cv.imwrite(filename_without_extension + "_sift.png", aligned_full)
        cv.imwrite(filename_without_extension + "_sift_overlay.png", overlay)

In [213]:
def group_and_order_filenames(filenames, maxTP, maxLevel):
    grouped_files = defaultdict(lambda: defaultdict(lambda: [None] * maxTP))
    pattern = r'^(?P<plant>[^_]+)_(?P<tube>\d+)_(?P<level>\d+)_(?P<date>\d{4}-\d{2}-\d{2})_TP(?P<timepoint>\d+)\.png$'
    #plant_tube_depth_yyyy-mm-dd_TP#

    for fname in filenames:

        match = re.match(pattern, fname)
        if match:
            plant = match.group('plant')
            tube = int(match.group('tube'))
            level = int(match.group('level'))
            date = match.group('date')
            timepoint = int(match.group('timepoint'))
            if 1 <= timepoint <= maxTP:
                grouped_files[tube][level - 1][timepoint - 1] = fname

    return grouped_files

In [219]:
imgfilelist = [f for f in os.listdir("slu_data") if f.endswith(".png")]
print(f"Found {len(imgfilelist)} image files")

imgfilegroups = group_and_order_filenames(imgfilelist, 12, 7)
print(f"Found {len(imgfilelist)} image groups")

Found 83 image files
Found 83 image groups


In [220]:
for tube, depths in imgfilegroups.items():
    for depth, tp_files in depths.items():
        # for i in range(0, len(tp_files) - 1):
        #     pipeline("slu_data/" + tp_files[i], "slu_data/" + tp_files[i + 1])
        non_none_files = filter(None, tp_files)
        my_iterator = iter(non_none_files)
        try:
            next_item = next(my_iterator)
            while next_item:
                current_item = next_item
                next_item = next(my_iterator)
                if current_item and next_item:
                    current_without_extension, current_extension = os.path.splitext(current_item)
                    if os.path.exists("slu_data/" + current_without_extension + "_sift.png"):
                        pipeline("slu_data/" + current_without_extension + "_sift.png", "slu_data/" + next_item)
                    else:
                        pipeline("slu_data/" + current_item, "slu_data/" + next_item)
        except StopIteration:
            continue

Not enough good matches for reliable registration.
