## Image Registration

In [4]:
import skimage
import numpy as np
import matplotlib.pyplot as plt
from skimage.registration import phase_cross_correlation
from scipy.signal import correlate
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from functools import partial

def update_beta(positions1,positions2, beta):
    """Adapt the position-correction step size ``beta`` from two sets of values.

    The zero-lag normalized cross-correlation of the flattened inputs,
    ``k = sum(a*b) / (||a|| * ||b||)``, is computed; it lies in [-1, 1].
    ``beta`` is increased by 10% when ``k > 0.3``, reduced by 10% when
    ``k < -0.3``, and left unchanged otherwise. It is also left unchanged when
    either input is all zeros (or empty), where ``k`` is undefined.

    NOTE(review): the +/-0.3 feedback rule only makes sense when the inputs are
    the position *updates* from two consecutive iterations rather than the
    absolute positions -- confirm against the caller.

    Parameters
    ----------
    positions1, positions2 : array_like
        Real-valued arrays with the same number of elements; they are
        flattened before being compared.
    beta : float
        Current step size.

    Returns
    -------
    float
        The adjusted step size.

    Raises
    ------
    ValueError
        If the inputs do not have the same number of elements.
    """
    a = np.asarray(positions1, dtype=float).ravel()
    b = np.asarray(positions2, dtype=float).ravel()
    if a.size != b.size:
        raise ValueError(
            "positions1 and positions2 must have the same number of elements, "
            f"got {a.size} and {b.size}"
        )

    # scipy.signal.correlate returns the whole lag-by-lag correlation array,
    # which cannot be compared against a threshold; the normalized zero-lag
    # value is the single scalar the +/-0.3 thresholds are meant for.
    norm = np.linalg.norm(a) * np.linalg.norm(b)
    if norm == 0:
        return beta  # correlation undefined for an all-zero input: keep beta
    k = np.dot(a, b) / norm

    threshold1 = +0.3
    threshold2 = -0.3

    if k > threshold1:
        beta = beta * 1.1  # increase by 10%
    elif k < threshold2:
        beta = beta * 0.9  # reduce by 10%
    # otherwise keep the same value
    return beta

def get_illuminated_mask(probe,probe_threshold):
    """Return a binary mask of the pixels the probe illuminates.

    A pixel is marked 1 when its probe amplitude ``|probe|`` is strictly
    greater than ``probe_threshold`` times the maximum amplitude, and 0
    otherwise.

    Parameters
    ----------
    probe : array_like
        Probe values (real or complex).
    probe_threshold : float
        Fraction of the peak amplitude a pixel must exceed. Values >= 1 yield
        an all-zero mask.

    Returns
    -------
    ndarray of int
        Array of 0/1 with the same shape as ``probe``.
    """
    # Compare magnitudes: ordering comparisons on complex values are not a
    # meaningful notion of "brighter".
    amplitude = np.abs(probe)
    if amplitude.size == 0:
        # np.max raises on empty input; an empty probe illuminates nothing.
        return np.zeros(amplitude.shape, dtype=int)
    mask = np.where(amplitude > np.max(amplitude) * probe_threshold, 1, 0)
    return mask

def correct_position(probe, probe_threshold, upsampling, beta, data):
    """Refine one scan position by registering an object patch against its previous estimate.

    Both patches are multiplied by the 0/1 mask from
    ``get_illuminated_mask(probe, probe_threshold)``, so that only illuminated
    pixels drive the registration. ``phase_cross_correlation`` then estimates
    the shift between them to 1/``upsampling`` of a pixel. The position is
    moved by ``beta`` times that shift.

    The fixed arguments come first so the function can be bound with
    ``functools.partial`` and mapped over per-position ``data`` tuples.

    Parameters
    ----------
    probe : ndarray
        Probe used to build the illumination mask; must broadcast against
        the patches.
    probe_threshold : float
        Fraction of the peak probe amplitude above which a pixel is kept.
    upsampling : int
        ``upsample_factor`` passed to ``phase_cross_correlation``.
    beta : float
        Step size applied to the measured shift.
    data : tuple
        ``(obj, previous_obj, position, index)``: current patch, previous
        patch, current position, and the position's index in the scan.

    Returns
    -------
    new_position : ndarray
        ``position + beta * shift``.
    index : int
        The index from ``data``, so results can be written back in place.
    """
    obj, previous_obj, position, index = data  # unpack the per-position inputs

    illumination_mask = get_illuminated_mask(probe, probe_threshold)

    obj = obj * illumination_mask
    previous_obj = previous_obj * illumination_mask

    # Per the skimage docs, the returned shift registers the second (moving)
    # image onto the first (reference) one, with axes in the input arrays'
    # axis order.
    # NOTE(review): confirm this sign and (row, col) ordering match how
    # `position` is stored.
    relative_shift, _error, _diffphase = phase_cross_correlation(
        obj, previous_obj, upsample_factor=upsampling
    )

    new_position = position + beta * relative_shift

    return new_position, int(index)

def position_correction(obj,previous_obj,positions, beta, probe_threshold=0.1, upsampling=100, probe=None):
    """Refine all scan positions in parallel, one ``correct_position`` call each.

    Entry ``i`` of ``obj``, ``previous_obj`` and ``positions`` is bundled with
    its index ``i`` and handed to ``correct_position`` in a worker process.
    The returned positions are written back at their indices.

    Parameters
    ----------
    obj, previous_obj : sequence of ndarray
        Current and previous object patches, one per scan position (iterated
        along the first axis).
    positions : ndarray
        Scan positions; ``positions[i]`` belongs to ``obj[i]``.
    beta : float
        Step size applied to each measured shift.
    probe_threshold : float, optional
        Fraction of the peak probe amplitude above which a pixel counts as
        illuminated (default 0.1).
    upsampling : int, optional
        Sub-pixel upsampling factor for ``phase_cross_correlation``
        (default 100).
    probe : ndarray, optional
        Probe used to build the illumination mask. When omitted, a uniform
        scalar probe is used, so every pixel is kept for any
        ``probe_threshold < 1``.

    Returns
    -------
    ndarray
        Corrected positions with the shape of ``positions``. The dtype is
        ``positions.dtype`` if it is floating point, else float64.

    Raises
    ------
    ValueError
        If ``obj``, ``previous_obj`` and ``positions`` differ in length.
    """
    try:
        from tqdm import tqdm as progress
    except ImportError:  # the progress bar is optional
        def progress(iterable, **_kwargs):
            return iterable

    positions = np.asarray(positions)
    n_positions = positions.shape[0]
    if len(obj) != n_positions or len(previous_obj) != n_positions:
        raise ValueError(
            "obj, previous_obj and positions must have the same length, got "
            f"{len(obj)}, {len(previous_obj)} and {n_positions}"
        )

    if probe is None:
        # A scalar 1.0 broadcasts over any patch shape, so nothing is masked.
        probe = np.array(1.0)

    # Integer indexes (np.linspace produced floats, which cannot index arrays).
    list_of_inputs = list(zip(obj, previous_obj, positions, range(n_positions)))

    correct_position_partial = partial(correct_position, probe, probe_threshold, upsampling, beta)

    # Integer positions would truncate the sub-pixel corrections.
    out_dtype = positions.dtype if np.issubdtype(positions.dtype, np.floating) else np.float64
    new_positions = np.zeros(positions.shape, dtype=out_dtype)

    # NOTE(review): worker processes must be able to import __main__ to
    # unpickle `correct_position`; in interactive sessions (e.g. notebooks
    # under the "spawn" start method) ThreadPoolExecutor may be needed instead.
    with ProcessPoolExecutor() as executor:
        results = list(progress(executor.map(correct_position_partial, list_of_inputs), total=n_positions))

    for position, index in results:
        new_positions[index] = position

    return new_positions
    
    
    
    
    
    
    