Image registration tool using cv2  

Possible future features:
* Preprocess images if needed (e.g. invert pixel values).
* Tune the SIFT parameters.
* Iterate over the Lowe's ratio-test threshold, e.g. values between 0.6 and 0.8 (lower is more selective).
* Change the RANSAC reprojection threshold: a smaller value increases precision (fewer incorrect matches are accepted as inliers) but may decrease recall (fewer correct matches are kept).
* Use both the mean distance and the number of inliers to assess registration quality (currently only the mean distance is used).
* Try different random seeds (possibly in a loop) to get better registration results.

In [7]:
import cv2
import numpy as np
import os
import logging
import matplotlib.pyplot as plt
try: 
    import aicspylibczi
except ImportError:
    logging.warning("CZI can't be read, as aicspylibczi is not installed")


In [8]:
# Let INFO-level messages from the root logger (used by every helper below
# via `logging.info`) reach the output.
logger = logging.getLogger()
logger.setLevel(level=logging.INFO)
# Emit one message so it is visible that INFO logging is active.
logging.info("test")

INFO:root:test


In [9]:
def read_czi_image(czi, channel=0, scale_factor=1.0):
    """Read one channel of a CZI mosaic as a single image plane.

    Args:
        czi: An open ``aicspylibczi.CziFile`` (or any object exposing a
            compatible ``read_mosaic`` method).
        channel: Currently ignored -- see the NOTE below.
        scale_factor: Forwarded to ``read_mosaic``; e.g. ``0.25`` reads the
            mosaic at a quarter of its full resolution. ``1.0`` (the
            default) reads it at full resolution.

    Returns:
        The first element of the array returned by ``read_mosaic``.
    """
    # NOTE(review): the channel is hard-coded to C=1 and `channel` is not
    # used. It is kept this way because existing callers (read_czi) call
    # this with the default and rely on getting C=1; honouring `channel`
    # would silently switch them to C=0. Confirm which channel is intended
    # before wiring the parameter through.
    full_channel = czi.read_mosaic(C=1, scale_factor=scale_factor)[0]
    return full_channel

In [10]:
def read_czi(image_path):
    """Open the CZI file at ``image_path``, log its metadata and return its image plane."""
    czi_file = aicspylibczi.CziFile(image_path)
    print_czi_metadata(czi_file)
    return read_czi_image(czi_file)

In [13]:
def print_czi_metadata(czi):
    """Log the dims, size, dims shape and mosaic flag of a CZI file.

    Despite the name, everything goes through ``logging.info`` rather than
    ``print``.
    """
    logging.info('CZI metadata:')
    # Plain attributes first, then the zero-argument query methods, in the
    # same order as they are logged.
    for attr_name in ('dims', 'size'):
        logging.info(getattr(czi, attr_name))
    for method_name in ('get_dims_shape', 'is_mosaic'):
        logging.info(getattr(czi, method_name)())

In [16]:
def read_image(image_path):
    """Read an image from disk, dispatching on the file extension.

    Files ending in ``.czi`` (matched case-insensitively) are read via
    ``read_czi``; everything else goes through ``cv2.imread``.

    Returns:
        The image array; its shape is logged.

    Raises:
        OSError: if ``cv2.imread`` cannot read the file (missing path,
            unsupported format, ...). ``cv2.imread`` signals this by
            returning ``None`` rather than raising, which previously
            surfaced later as an AttributeError on ``.shape``.
    """
    if os.path.splitext(image_path)[-1].lower() == '.czi':
        image = read_czi(image_path)
    else:
        # os.fspath lets pathlib.Path objects through; cv2 wants a str.
        image = cv2.imread(os.fspath(image_path))
        if image is None:
            raise OSError(f'cv2 could not read image: {image_path}')

    logging.info(f'{image_path} image shape: {image.shape}')
    return image

In [17]:
def get_image_mid_part(image, max_size=1000):
    """Return a centred crop of at most ``max_size`` x ``max_size`` pixels.

    Only the first two axes (rows, columns) are cropped, so a trailing
    channel axis is kept. Axes no longer than ``max_size`` are kept whole.

    Args:
        image: Array with at least two dimensions.
        max_size: Maximum crop height and width. Defaults to 1000.

    Returns:
        A slice (view) of ``image``.
    """
    full_height, full_width = image.shape[:2]
    height = min(full_height, max_size)
    width = min(full_width, max_size)

    # Offsets derived from the remainder rather than centre +/- half the crop
    # size: the latter drops a row/column whenever the crop size is odd
    # (e.g. a 999-row image came back with 998 rows).
    top = (full_height - height) // 2
    left = (full_width - width) // 2
    return image[top:top + height, left:left + width]

In [20]:
def log_image_metadata(image, is_preprocessed=True):
    """Log the shape, dtype and (min, max) value range of ``image``.

    When ``is_preprocessed`` is true, each message is tagged with
    ' after preprocessing'.
    """
    if is_preprocessed:
        processed_str = ' after preprocessing'
    else:
        processed_str = ''

    logging.info(f'image shape{processed_str}: {image.shape}')
    logging.info(f'image dtype{processed_str}: {image.dtype}')
    logging.info(f'image range{processed_str}: {np.min(image), np.max(image)}')

In [19]:
def preprocess_image(image):
    """Convert ``image`` to single-channel uint8 stretched to the 0-255 range.

    A 3-D input is first reduced to one channel:
    * 3 channels are treated as BGR (OpenCV order);
    * 4 channels are treated as BGRA;
    * a trailing singleton channel axis is dropped.
    The intensities are then min-max normalised to 0-255 and cast to uint8.
    """
    if image.ndim == 3:
        n_channels = image.shape[2]
        if n_channels == 1:
            # (H, W, 1) -> (H, W); COLOR_BGR2GRAY rejects 1-channel input.
            image = np.ascontiguousarray(image[:, :, 0])
        elif n_channels == 4:
            # COLOR_BGR2GRAY rejects 4-channel input.
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')
    # Optional contrast enhancement, currently disabled:
    # image = cv2.equalizeHist(image)

    return image

In [21]:
def read_and_preprocess_images(source_image_path, destination_image_path, is_preprocess=True):
    """Read the source and destination images and optionally preprocess them.

    Args:
        source_image_path: Path of the image to be registered (moved).
        destination_image_path: Path of the reference image.
        is_preprocess: If true, both images go through ``preprocess_image``.

    Returns:
        ``(src_img, dst_img)``.
    """
    src_img = read_image(source_image_path)
    dst_img = read_image(destination_image_path)

    # Log both raw images (this used to reference an undefined name `image`).
    for img in (src_img, dst_img):
        log_image_metadata(img, is_preprocessed=False)

    if is_preprocess:
        src_img = preprocess_image(src_img)
        dst_img = preprocess_image(dst_img)
        for img in (src_img, dst_img):
            log_image_metadata(img)

    return src_img, dst_img

In [None]:
def def_flann_matches():
    """Create a FLANN (Fast Library for Approximate Nearest Neighbors) matcher.

    Uses FLANN's kd-tree index (algorithm id 1) with 5 trees and 50 checks
    per search.
    """
    FLANN_INDEX_KDTREE = 1
    return cv2.FlannBasedMatcher(
        dict(algorithm=FLANN_INDEX_KDTREE, trees=5),
        dict(checks=50),
    )

In [None]:
# def def_bf_matcher():
#     bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
#     return bf

In [None]:
def get_good_matches(matches, ratio=0.8):
    """Filter k-nearest-neighbour matches with Lowe's ratio test.

    Args:
        matches: Iterable of per-query neighbour sequences as returned by
            ``knnMatch(..., k=2)``. The first two entries are taken as the
            best and second-best neighbour; each exposes ``.distance``.
        ratio: A match is kept when its distance is below ``ratio`` times
            the second-best distance. Lower is more selective (values
            between 0.6 and 0.8 are worth trying). Defaults to 0.8.

    Returns:
        List containing the best match of every query that passes the test.
    """
    good_matches = []
    for neighbours in matches:
        # knnMatch can return fewer than k neighbours for a query. The ratio
        # test needs two, so such queries are skipped instead of crashing
        # on tuple unpacking.
        if len(neighbours) < 2:
            continue
        best, second = neighbours[0], neighbours[1]
        if best.distance < ratio * second.distance:
            good_matches.append(best)

    return good_matches

In [None]:
def get_pts_from_matches(good_matches, kp1, kp2, min_matches=11):
    """Collect matched keypoint coordinates in the (N, 1, 2) layout OpenCV expects.

    Args:
        good_matches: Matches exposing ``queryIdx`` (index into ``kp1``) and
            ``trainIdx`` (index into ``kp2``).
        kp1: Source-image keypoints, each with a ``.pt`` (x, y) pair.
        kp2: Destination-image keypoints, each with a ``.pt`` (x, y) pair.
        min_matches: Minimum number of matches required. Defaults to 11,
            i.e. strictly more than 10.

    Returns:
        ``(src_pts, dst_pts)``: float32 arrays of shape (N, 1, 2).

    Raises:
        AssertionError: if fewer than ``min_matches`` matches are given.
    """
    if len(good_matches) < min_matches:
        raise AssertionError("Not enough matches found for image registration.")

    src_pts = np.float32([kp1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
    return src_pts, dst_pts

In [None]:
def plot_good_matches(src_img, dst_img, src_pts, dst_pts):
    """Show both images side by side with their matched points numbered.

    Point ``i`` in the source panel corresponds to point ``i`` in the
    destination panel.

    Args:
        src_img: Source image to display.
        dst_img: Destination image to display.
        src_pts: Matched (x, y) coordinates in ``src_img``. Accepts either
            (N, 1, 2), as returned by ``get_pts_from_matches``, or (N, 2).
        dst_pts: Matched (x, y) coordinates in ``dst_img``, same layout.
    """
    fig, ax = plt.subplots(1, 2, figsize=(20, 10))

    for axis, img, pts in ((ax[0], src_img, src_pts), (ax[1], dst_img, dst_pts)):
        # Flatten to (N, 2) so x and y are plain 1-D arrays whatever the
        # input layout.
        xy = np.asarray(pts).reshape(-1, 2)
        axis.imshow(img, cmap='gray')
        axis.axis('off')
        # linestyle='none': markers only, no line connecting the points.
        axis.plot(xy[:, 0], xy[:, 1], linestyle='none', marker='v', color="red")
        for i, (x, y) in enumerate(xy):
            axis.annotate(str(i), (x, y))

In [None]:
def test_registration_by_kpts_mean_sqrt(src_pts, dst_pts, H, mask):
    """Mean reprojection error of the RANSAC inliers under homography ``H``.

    Each source point is mapped through ``H`` and its Euclidean distance to
    the matched destination point is measured. Only entries whose ``mask``
    value is 1 are averaged.

    Args:
        src_pts: Matched source points, shape (N, 1, 2) as produced by
            ``get_pts_from_matches``.
        dst_pts: Matched destination points, same shape as ``src_pts``.
        H: 3x3 homography mapping source to destination coordinates.
        mask: Inlier mask from ``cv2.findHomography`` (one entry per point).

    Returns:
        Mean inlier distance in pixels, or ``np.inf`` when there are no
        inliers. A caller keeping the lowest distance therefore never
        selects such a result.
    """
    mapped_src_pts = cv2.perspectiveTransform(src_pts, H)

    # Norm over the last (x, y) axis -> shape (N, 1), aligned with `mask`.
    distances = np.linalg.norm(mapped_src_pts - dst_pts, axis=2)
    inlier_distances = distances[mask == 1]

    if inlier_distances.size == 0:
        # np.mean of an empty array would return nan with a RuntimeWarning.
        logging.warning('No RANSAC inliers; mean keypoint distance is undefined.')
        return np.inf

    mean_distance = np.mean(inlier_distances)

    logging.info(f'Mean Euclidean distance between mapped source keypoints and destination keypoints: {mean_distance}')

    return mean_distance

In [None]:
def test_registration_by_num_inliers(mask):
    """Log RANSAC inlier/outlier counts from ``mask`` and return the inlier count.

    Entries equal to 1 count as inliers; entries equal to 0 count as
    outliers.
    """
    flat_mask = mask.ravel()
    num_inliers = int(np.count_nonzero(flat_mask == 1))
    num_outliers = int(np.count_nonzero(flat_mask == 0))

    logging.info(f'RANSAC number of Inliers: {num_inliers}')
    logging.info(f'RANSAC number of Outliers: {num_outliers}')

    return num_inliers

In [None]:
def check_degenerate(points, threshold=1.0):
    """Heuristically flag point sets that are (nearly) collapsed along an axis.

    The standard deviation is taken per coordinate over the first axis. Only
    axis-aligned collapse is detected; points spread along a diagonal line
    are not flagged.

    Args:
        points: Array of (x, y) points, e.g. shape (N, 1, 2).
        threshold: Minimum per-coordinate standard deviation, in pixels,
            for the points to count as well spread. Defaults to 1.0.

    Returns:
        True if ``points`` is empty or the spread of any coordinate is below
        ``threshold`` (a warning is logged); False otherwise.
    """
    points = np.asarray(points)
    # np.std of an empty array is nan, which compares False and would report
    # "not degenerate"; an empty set is treated as degenerate instead.
    if points.size == 0 or np.any(np.std(points, axis=0) < threshold):
        logging.warning("Points may be in a degenerate configuration.")
        return True
    return False

In [None]:
def check_transformation_quality(H, src_img, num_points=50):
    """Sanity-check a homography by how it spreads a grid of source points.

    A ``num_points`` x ``num_points`` grid spanning ``src_img`` is mapped
    through ``H``. The transformation is flagged as poor when the
    per-coordinate standard deviation of the mapped grid is:
    * non-finite;
    * below 1 pixel (grid collapsed); or
    * above the larger image side (grid blown up).

    Args:
        H: 3x3 homography, or ``None`` (e.g. a failed estimation).
        src_img: Source image; only its height and width are used.
        num_points: Grid points per axis. Defaults to 50.

    Returns:
        True if the transformation looks plausible, False otherwise (a
        warning is logged).
    """
    if H is None:
        logging.warning("No homography given; transformation may be poor.")
        return False

    height, width = src_img.shape[:2]
    x = np.linspace(0, width, num=num_points)
    y = np.linspace(0, height, num=num_points)
    xv, yv = np.meshgrid(x, y)
    points = np.column_stack([xv.ravel(), yv.ravel()])

    points_t = cv2.perspectiveTransform(points.reshape(-1, 1, 2), H)
    std_dev = np.std(points_t, axis=0)

    # nan compares False against both thresholds and would pass as "good",
    # so non-finite spreads (e.g. from a non-finite H) are checked first.
    if (not np.all(np.isfinite(std_dev))
            or np.any(std_dev < 1.0)
            or np.any(std_dev > max(height, width))):
        logging.warning("Transformation may be poor.")
        return False

    return True


In [None]:
def test_registration(src_pts, dst_pts, H, mask):
    """Score a homography; returns ``(mean inlier distance, number of inliers)``."""
    return (
        test_registration_by_kpts_mean_sqrt(src_pts, dst_pts, H, mask),
        test_registration_by_num_inliers(mask),
    )

In [None]:
def get_sift_based_registration_matrix(src_img, dst_img, detector=None, matcher=None):
    """Estimate the homography mapping ``src_img`` onto ``dst_img``.

    Pipeline:
    1. Keypoint detection and description.
    2. 2-nearest-neighbour matching.
    3. Lowe's ratio test.
    4. RANSAC homography (reprojection threshold 5 px).
    5. Quality logging.

    Args:
        src_img: Image to be registered.
        dst_img: Reference image.
        detector: Object with ``detectAndCompute``. Defaults to the global
            ``sift`` created in the ``__main__`` block.
        matcher: Object with ``knnMatch``. Defaults to the global ``flann``
            created in the ``__main__`` block.

    Returns:
        ``(H, mean_distance)``: the 3x3 homography and the mean inlier
        reprojection distance in pixels.

    Raises:
        AssertionError: any of
            * an image yields no descriptors;
            * too few matches survive the ratio test;
            * no homography can be estimated.
    """
    # Fall back to the module globals so existing callers keep working.
    if detector is None:
        detector = sift
    if matcher is None:
        matcher = flann

    kp1, des1 = detector.detectAndCompute(src_img, None)
    kp2, des2 = detector.detectAndCompute(dst_img, None)
    # detectAndCompute returns None descriptors when it finds no keypoints,
    # which knnMatch cannot handle.
    if des1 is None or des2 is None:
        raise AssertionError("No descriptors found in one of the images.")

    # NOTE(review): run-to-run differences presumably originate in this
    # approximate FLANN search rather than in the (deterministic) ratio test
    # applied below -- confirm.
    matches = matcher.knnMatch(des1, des2, k=2)
    good_matches = get_good_matches(matches)

    src_pts, dst_pts = get_pts_from_matches(good_matches, kp1, kp2)

    if check_degenerate(src_pts) or check_degenerate(dst_pts):
        # TODO: handle degenerate configurations; for now only warn.
        logging.warning("Degenerate configuration detected; continuing anyway.")

    H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5)
    # findHomography returns None when it cannot estimate a homography.
    if H is None:
        raise AssertionError("Could not estimate a homography from the matches.")

    mean_distance, _num_inliers = test_registration(src_pts, dst_pts, H, mask)
    check_transformation_quality(H, src_img)

    return H, mean_distance

In [None]:
def warp_image(dst_img, src_img, H):
    """Warp ``src_img`` by homography ``H`` onto the pixel grid of ``dst_img``.

    Only the height and width of ``dst_img`` are used, so it may be
    single- or multi-channel.

    Returns:
        The warped source image, with ``dst_img``'s height and width.
    """
    # shape[:2] works for both (H, W) and (H, W, C); unpacking the full shape
    # raised for multi-channel images.
    height, width = dst_img.shape[:2]
    return cv2.warpPerspective(src_img, H, (width, height))

In [None]:
def save_results(output_path, h, registered_image):
    """Write the registered image and the homography into ``output_path``.

    Creates the directory if needed. Writes ``registered_image.png`` (via
    OpenCV) and ``registration_matrix.npy`` (via ``np.save``).

    Raises:
        OSError: if OpenCV fails to write the PNG. ``cv2.imwrite`` reports
            failure by returning False rather than raising.
    """
    os.makedirs(output_path, exist_ok=True)
    image_file = os.path.join(output_path, 'registered_image.png')
    if not cv2.imwrite(image_file, registered_image):
        raise OSError(f'Failed to write {image_file}')
    np.save(os.path.join(output_path, 'registration_matrix.npy'), h)

In [None]:
def plot_result(source_image, destination_image, registered_image):
    """Plot the registration result in a 2x2 grid.

    Panels: source, destination, registered image, and the registered
    image overlaid on the destination.
    """
    fig, ax = plt.subplots(2, 2, figsize=(20, 20))

    panels = (
        ((0, 0), source_image, 'source image'),
        ((0, 1), destination_image, 'destination image'),
        ((1, 0), registered_image, 'registered image'),
    )
    for position, img, title in panels:
        ax[position].imshow(img, cmap='gray')
        ax[position].set_title(title)
        ax[position].axis('off')

    # Registered image drawn semi-transparently, in a different colormap, on
    # top of the destination.
    overlay_ax = ax[1, 1]
    overlay_ax.imshow(destination_image, cmap='gray')
    overlay_ax.imshow(registered_image, alpha=0.3, cmap='hot_r')
    overlay_ax.set_title('overlay')
    overlay_ax.axis('off')

In [None]:
## add a function for scale factor

In [None]:
if __name__ == "__main__":

    # SIFT params: (int nfeatures=0, int nOctaveLayers=3, double contrastThreshold=0.04, double edgeThreshold=10, double sigma=1.6)
    # NOTE(review): an earlier note read "3>nOctaveLayers>5", presumably
    # meaning "try nOctaveLayers between 3 and 5" -- confirm.
    sift = cv2.SIFT_create()
    flann = def_flann_matches()

    # source_img_path = "../akalin_assay_coregistration/data/morphology.png"
    # destination_img_path = "../akalin_assay_coregistration/data/tissue_lowres_image.png"

    source_img_path = '../burgstaller_channel_alignment/data/CPC_Pestoni_lung_24-01-2023_RUN1_slide0.czi'
    destination_img_path = '../burgstaller_channel_alignment/data/CPC_Pestoni_RUN2_Slide0.czi'

    output_path = "registration_output"

    source_image, destination_image = read_and_preprocess_images(source_img_path, destination_img_path)

    # Try 90-degree rotations of the source image and keep the homography
    # with the lowest mean inlier distance. Set n_rotations = 4 to try
    # 0, 90, 180 and 270 degrees.
    n_rotations = 1
    lowest_mean_distance = np.inf
    H = None
    best_source_image = source_image
    rotated_source = source_image

    for i in range(n_rotations):
        if i > 0:
            rotated_source = cv2.rotate(rotated_source, cv2.ROTATE_90_CLOCKWISE)
        # Re-seed OpenCV's global RNG before every attempt (presumably for
        # reproducibility between attempts).
        cv2.setRNGSeed(0)
        logging.info(f'Source image Rotated - {90 * i}')
        h, mean_distance = get_sift_based_registration_matrix(rotated_source, destination_image)
        if mean_distance < lowest_mean_distance:
            lowest_mean_distance = mean_distance
            H = h
            best_source_image = rotated_source
        logging.info('\n')

    if H is None:
        raise RuntimeError('No rotation produced a usable registration.')

    # Warp the source in the orientation the best homography was estimated
    # for, and save that homography (not the last one tried).
    registered_image = warp_image(destination_image, best_source_image, H)

    save_results(output_path, H, registered_image)
    plot_result(source_image, destination_image, registered_image)

In [None]:
# Rescale the homography to full-resolution coordinates. Assuming H was
# estimated on images downscaled by `scale` (x_small = S @ x_full), the
# full-resolution homography is H_full = S^-1 @ H @ S.
# NOTE(review): confirm H really was estimated at this scale.
scale = 0.25

S = np.array([[scale, 0, 0], [0, scale, 0], [0, 0, 1]])

H_full = np.linalg.inv(S) @ H @ S

In [None]:
# Re-read both images without preprocessing and warp the source image.
source_img_path = '../burgstaller_channel_alignment/data/CPC_Pestoni_lung_24-01-2023_RUN1_slide0.czi'
destination_img_path = '../burgstaller_channel_alignment/data/CPC_Pestoni_RUN2_Slide0.czi'

source_image, destination_image = read_and_preprocess_images(
    source_img_path,
    destination_img_path,
    is_preprocess=False,
)

# NOTE(review): H_full (the homography rescaled by `scale`) is computed in
# the preceding cell but not used here. If H was estimated on downscaled
# images, this should probably warp with H_full instead -- confirm.
registered_image = warp_image(dst_img=destination_image, src_img=source_image, H=H)

In [None]:
# Show source, destination, registered image and overlay in a 2x2 grid.
plot_result(
    source_image=source_image,
    destination_image=destination_image,
    registered_image=registered_image,
)

In [None]:
# Same 2x2 layout as plot_result, except the overlay panel draws the
# *unregistered* source image over the destination (presumably to compare
# against the registered overlay -- confirm intent).
fig, ax = plt.subplots(2, 2, figsize=(20, 20))
for (row, col), img, title in (
    ((0, 0), source_image, 'source image'),
    ((0, 1), destination_image, 'destination image'),
    ((1, 0), registered_image, 'registered image'),
):
    ax[row, col].imshow(img, cmap='gray')
    ax[row, col].set_title(title)
    ax[row, col].axis('off')
ax[1, 1].imshow(destination_image, cmap='gray')
ax[1, 1].imshow(source_image, alpha=0.3, cmap='hot_r')
ax[1, 1].set_title('overlay')
ax[1, 1].axis('off')