# Rigid Registration
2020-07-28



In [None]:
# autoreloads imported modules when modified
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append('..')

import cv2
from imageio import imread
import numpy as np
import matplotlib.pyplot as plt
from skimage import transform
from tqdm.notebook import tqdm
from pandas import DataFrame, read_csv, concat
import seaborn as sns

from modules.utils import parse_tif_dir, normalize_image
from modules.registration import register_images_adv

from os.path import join
from os import makedirs

## Variables to set before running the loop

In [None]:
# Run configuration: adjust the values in this cell before executing the loop.

# round that every other round is registered against, and the channels to test
target_round = 1
channels = [1, 2]

# input images, looked up later as filepaths[round][channel];
# quench must be set to False
filepaths = parse_tif_dir('/media/Registration/Datasets/tonsils/', quench=False)

# feature extractor under test; the commented lines are the alternatives
# ft_extractor = cv2.AKAZE_create()
# ft_extractor = cv2.KAZE_create(extended=True)
ft_extractor = cv2.xfeatures2d.SURF_create(hessianThreshold=1500)

# brute-force descriptor matcher using the L2 norm with cross-checking enabled
matcher = cv2.BFMatcher(normType=cv2.NORM_L2, crossCheck=True)

# directory that receives the saved figures and the results table
save_dir = '/mnt/RigidRegistrationFigures/tonsils_SURF_BF-Matcher'
makedirs(save_dir, exist_ok=True)

# results accumulator: the list-valued columns get one entry per registration,
# while target_round stays a single scalar shared by every row
data = dict(
    target_round=target_round,
    moving_round=[],
    channel=[],
    raw_error=[],
    registered_error=[],
    num_kpts=[],
)

## Run the loop

In [None]:
# Switch matplotlib to the Agg backend so figures are not shown; this prevents the kernel from dying.
%matplotlib agg

# Register every non-target round against the target round, one channel at a
# time, and collect the reported errors and keypoint counts into `data`.
max_round = max(filepaths.keys())
all_rounds = range(1, max_round + 1)

for channel in channels:
    print(f'Running on channel {channel}')

    # each channel writes its figures into its own sub-directory
    channel_dir = join(save_dir, f'channel_{channel}')
    makedirs(channel_dir, exist_ok=True)

    # the target image stays fixed for the whole channel
    target_path = filepaths[target_round][channel]
    target_im = normalize_image(imread(target_path))

    for _round in tqdm(all_rounds, total=len(all_rounds)):
        # the target round is never registered against itself
        if _round == target_round:
            continue

        moving_path = filepaths[_round][channel]
        moving_im = normalize_image(imread(moving_path))

        figure_name = f'targetRound-{target_round}_movingRound-{_round}_channel-{channel}.png'
        save_path = join(channel_dir, figure_name)

        # run the registration; the image output path is handed over via savepath
        results = register_images_adv(
            moving_im,
            target_im,
            ft_extractor,
            matcher,
            visuals=False,
            savepath=save_path,
        )

        # record one row for this round/channel pair
        data['moving_round'].append(_round)
        data['channel'].append(channel)
        data['raw_error'].append(results['error (raw)'])
        data['registered_error'].append(results['error (registered)'])
        data['num_kpts'].append(results['n_filtered_kpts'])

# write the collected results next to the figures as a CSV table
results_csv = join(save_dir, 'results.csv')
df = DataFrame(data)
df.to_csv(results_csv, index=False)

## Box Plots

In [None]:
%matplotlib inline

# Compare the registered error across feature extractors: load the results
# table saved by each run and tag its rows with the extractor name in 'ft'.
df1 = read_csv('/mnt/RigidRegistrationFigures/normalBreast_KAZE_BF-Matcher/results.csv')
df1['ft'] = 'KAZE'
df2 = read_csv('/mnt/RigidRegistrationFigures/normalBreast_AKAZE_BF-Matcher/results.csv')
df2['ft'] = 'AKAZE'
df3 = read_csv('/mnt/RigidRegistrationFigures/normalBreast_SURF_BF-Matcher/results.csv')
df3['ft'] = 'SURF'

# ignore_index gives the combined table a unique 0..n-1 index instead of
# repeating each file's own row labels
df = concat([df1, df2, df3], ignore_index=True)

# swarm plot with one point per row: extractor on the x axis, hue by channel
g = sns.catplot(x='ft', y='registered_error', hue='channel', data=df, kind='swarm', height=10)

In [None]:
# Print the documentation of sns.catplot.
help(sns.catplot)