# Rigid Registration
2020-07-28



In [1]:
# autoreloads imported modules when modified
%load_ext autoreload
%autoreload 2

In [40]:
import sys
sys.path.append('..')

import cv2
from imageio import imread
import numpy as np
import matplotlib.pyplot as plt
from skimage import transform
from tqdm.notebook import tqdm
from pandas import DataFrame, read_csv, concat
import seaborn as sns
from copy import deepcopy

from modules.utils import parse_tif_dir, normalize_image
from modules.registration import register_images_adv, tre_distance, apply_transform

from os.path import join
from os import makedirs

## Variables to set before running the loop

In [None]:
# set variables for run

# quench must be set to False
filepaths = parse_tif_dir('/media/Registration/Datasets/tonsils/', quench=False)

# ft_extractor = cv2.AKAZE_create()
# ft_extractor = cv2.KAZE_create(extended=True)
ft_extractor = cv2.xfeatures2d.SURF_create(hessianThreshold=1500)

matcher = cv2.BFMatcher(normType=cv2.NORM_L2, crossCheck=True)

# save path
save_dir = '/mnt/RigidRegistrationFigures/tonsils_SURF_BF-Matcher'
makedirs(save_dir, exist_ok=True)

target_round = 1
channels = [1, 2]  # channels to test out

data = {'target_round': target_round, 'moving_round': [], 'channel': [], 'raw_error': [], 'registered_error': [],
        'num_kpts': []}

## Run the loop

In [None]:
# add this to avoid showing figures, prevents kernel dying
%matplotlib agg

max_round = max(list(filepaths.keys()))

for channel in channels:
    print(f'Running on channel {channel}')
    
    channel_dir = join(save_dir, f'channel_{channel}')
    makedirs(channel_dir, exist_ok=True)
    
    # get target image
    target_im = normalize_image(imread(filepaths[target_round][channel]))

    # loop through the rest of the rounds
    for _round in tqdm(range(1, max_round+1), total=len(range(1, max_round+1))):
        # skip the target round
        if _round != target_round: 
            # get moving image
            moving_im = normalize_image(imread(filepaths[_round][channel]))

            save_path = join(channel_dir, f'targetRound-{target_round}_movingRound-{_round}_channel-{channel}.png')

            # run registration with saving image
            results = register_images_adv(moving_im, target_im, ft_extractor, matcher, visuals=False, 
                                    savepath=save_path)
            
            # add results to data
            data['moving_round'].append(_round)
            data['channel'].append(channel)
            data['raw_error'].append(results['error (raw)'])
            data['registered_error'].append(results['error (registered)'])
            data['num_kpts'].append(results['n_filtered_kpts'])
            
# save the data as dataframe
df = DataFrame(data)
df.to_csv(join(save_dir, 'results.csv'), index=False)

## Box Plots

In [None]:
%matplotlib inline

df = DataFrame(
    columns=['target_round', 'moving_round', 'channel', 'raw_error', 'registered_error', 'num_kpts', 'ft'])


df1 = read_csv('/mnt/RigidRegistrationFigures/normalBreast_KAZE_BF-Matcher/results.csv')
df1['ft'] = 'KAZE'
df2 = read_csv('/mnt/RigidRegistrationFigures/normalBreast_AKAZE_BF-Matcher/results.csv')
df2['ft'] = 'AKAZE'
df3 = read_csv('/mnt/RigidRegistrationFigures/normalBreast_SURF_BF-Matcher/results.csv')
df3['ft'] = 'SURF'

df = concat([df1, df2, df3])

g = sns.catplot(x='ft', y='registered_error', hue='channel', data=df, kind='swarm', height=10)

In [None]:
help(sns.catplot)

## Version two of the workflow

In [53]:
# variables for the run
data_dir = '/media/Registration/Datasets/normalBreast/'

target_round = 1
dapi_round = 1

matcher = cv2.BFMatcher(normType=cv2.NORM_L2, crossCheck=True)
ft_ext = ft_extractor = cv2.AKAZE_create()
transformer = transform.SimilarityTransform()

# parse the directory
impaths = parse_tif_dir(data_dir, quench=False)

# get the list of rounds
rounds = list(impaths.keys())
rounds.sort()

# remove first index if it is zero
rounds = rounds[1:] if rounds[0] == 0 else rounds

# get target DAPI image
target_dapi_im = normalize_image(imread(impaths[target_round][dapi_round]))

# extract features from target DAPI image
# target_dapi_kpts, target_dapi_des = ft_ext.detectAndCompute(target_dapi_im, None)

data = {'target_round': target_round, 'moving_round': [], 'method': [], 'desc': [], 'error': []}

# loop through the rest of the rounds
for r in rounds:
    if r == target_round:
        continue
    
    # register DAPI channel
    
    # moving DAPI image and descriptors
#     moving_dapi_im = normalize_image(imread(impaths[r][dapi_round]))
#     moving_dapi_kpts, moving_dapi_des = ft_ext.detectAndCompute(moving_dapi_im, None)

    moving_dapi_filtered_pts, target_dapi_filtered_pts = filter_keypoints(
        moving_dapi_kpts, moving_dapi_des, target_dapi_kpts, target_dapi_des
    )
    
    # calculate the unregistered error
    h, w = target_dapi_im.shape[:2]
    unregistered_error = tre_distance(moving_dapi_filtered_pts, target_dapi_filtered_pts, h, w)
    
    # add to data
    data['moving_round'].append(r)
    data['method'].append('unregistered')
    data['desc'].append('Breast (AKAZE)')
    data['error'].append(unregistered_error)
    
    # register the DAPI channel
    registered_dapi_im = apply_transform(moving_dapi_im, target_dapi_im, moving_dapi_filtered_pts, 
                                         target_dapi_filtered_pts, transformer)[0]
    registered_dapi_im = (registered_dapi_im * 255).astype(np.uint8)
    
    # get features
    moving_reg_kpts, moving_reg_des = ft_ext.detectAndCompute(registered_dapi_im, None)
    
    moving_reg_filtered_pts, target_reg_filtered_pts = filter_keypoints(
        moving_reg_kpts, moving_reg_des, target_dapi_kpts, target_dapi_des
    )
    
    h, w = moving_dapi_im.shape[:2]
    registered_error = tre_distance(moving_reg_filtered_pts, target_reg_filtered_pts, h, w)
    data['moving_round'].append(r)
    data['method'].append('registered DAPI')
    data['desc'].append('Breast (AKAZE)')
    data['error'].append(registered_error)
    
    
    # register using second channel
    
    
    
    break
            
            

In [54]:
data

{'target_round': 1,
 'moving_round': [2, 2],
 'method': ['unregistered', 'registered DAPI'],
 'desc': ['Breast (AKAZE)', 'Breast (AKAZE)'],
 'error': [0.02590011, 0.00011252031]}

In [43]:
def filter_keypoints(moving_kpts, moving_des, target_kpts, target_des):
    matches = matcher.match(moving_des, target_des)
    moving_matched_kpts = [moving_kpts[match.queryIdx] for match in matches]
    target_matched_kpts = [target_kpts[match.trainIdx] for match in matches]
    
    # convert matched keypoints to array
    moving_matched_pts = np.float32([kpt.pt for kpt in moving_matched_kpts])  # (x,y) coords
    target_matched_pts = np.float32([kpt.pt for kpt in target_matched_kpts])  # (x,y) coords
    
    # filter keypoints using RANSAC mask
    mask = cv2.findHomography(
        moving_matched_pts, target_matched_pts, cv2.RANSAC, ransacReprojThreshold=10)[1]
    moving_filtered_kpts = [moving_matched_kpts[i] for i in np.arange(0, len(mask)) if mask[i] == [1]]
    target_filtered_kpts = [target_matched_kpts[i] for i in np.arange(0, len(mask)) if mask[i] == [1]]
    
    # convert filtered keypoints to array
    moving_filtered_pts = np.float32([kpt.pt for kpt in moving_filtered_kpts])
    target_filtered_pts = np.float32([kpt.pt for kpt in target_filtered_kpts])
    
    return moving_filtered_pts, target_filtered_pts