# Automated feature finding and co-registration

## Import dependencies

In [None]:
import scipy
import time
import pickle
import os
import re
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import numpy as np
import h5py
from scipy import ndimage as ndi
from tqdm import tqdm
import matplotlib as mpl
mpl.rcParams['savefig.dpi'] = 600
# Global font settings applied to every figure drawn by this notebook.
# NOTE(review): "font.family" names a concrete font (Arial), so the
# "font.sans-serif" fallback list is presumably only consulted if the family
# is switched to the generic 'sans-serif' - and Times New Roman is actually a
# serif face. Confirm which font is intended.
pgf_with_rc_fonts = dict([
    ("font.family", "Arial"),
    ("font.serif", []),
    ("font.size", 20),
    ("axes.titlesize", 20),
    ("font.sans-serif", ["Times New Roman"]),
])
mpl.rcParams.update(pgf_with_rc_fonts)
import matplotlib.pyplot as plt
import cv2
from scipy import signal, stats
from sklearn.decomposition import NMF
from numpy import genfromtxt
from scipy import stats

## Helper functions

In [None]:
def cartesian_distance(a, b):
    """Return the Euclidean distance between two 2-D points.

    Parameters
    ----------
    a, b : pair of numbers
        Points given as ``(x, y)``.

    Returns
    -------
    float
        ``sqrt((bx - ax)**2 + (by - ay)**2)``.
    """
    (ax, ay), (bx, by) = a, b
    dx = bx - ax
    dy = by - ay
    return np.sqrt(dx ** 2 + dy ** 2)

def crawldir(topdir=None, ext='sxm'):
    """Recursively collect files whose name contains ``'.' + ext``.

    Parameters
    ----------
    topdir : str or os.PathLike
        Root directory to walk. Required: ``None`` raises ``TypeError``.
        The old ``[]`` default was never a usable path either.
    ext : str
        Extension text that must follow a literal dot. It is matched
        literally (regex metacharacters are escaped) and anywhere in the file
        name, so ``'a.sxm.bak'`` also matches ``ext='sxm'``.

    Returns
    -------
    dict
        Maps each directory (as yielded by ``os.walk``) that contains at least
        one match to the list of matching file paths inside it.
    """
    if topdir is None:
        raise TypeError('crawldir() requires a topdir to search')
    pattern = re.compile(r'\.' + re.escape(ext))
    fn = dict()
    for root, dirs, files in os.walk(topdir):
        for name in files:
            if pattern.search(name):
                fn.setdefault(root, []).append(os.path.join(root, name))
    return fn
        
def gamma_correlation(array, feature, window_size):
    """Spearman rank correlation between an image window and a feature template.

    Parameters
    ----------
    array : 2-D ndarray
        Image window. Expected to be ``(2*window_size, 2*window_size)`` so it
        matches the resized template (``FeatureFinder.sliding_function``
        slices windows of that size).
    feature : object with an ``arr`` attribute
        Template provider, e.g. a ``feature`` instance.
    window_size : int
        Half-width of the window. The template is resized to
        ``(2*window_size, 2*window_size)``.

    Returns
    -------
    ndarray of shape (2,)
        ``[correlation, p-value]`` from ``scipy.stats.spearmanr``, or
        ``[0., 0.]`` when spearmanr rejects the inputs (``ValueError``, e.g.
        mismatched sizes). An array is always returned, so callers that
        flatten the result get a consistent shape.
    """
    # Min-max normalise the window. Spearman is rank based, so this does not
    # change the score. The guard avoids 0/0 (NaN) on a constant window.
    in_array = array - np.min(array)
    peak = np.max(in_array)
    if peak > 0:
        in_array = in_array / peak

    arr = cv2.resize(feature.arr, (2 * window_size, 2 * window_size))

    try:
        AA = np.array(stats.spearmanr(np.ndarray.flatten(arr),
                                      np.ndarray.flatten(in_array)))
    except ValueError:
        AA = np.zeros(2)
    return AA

## function_ls class

In [None]:
class function_ls:
    '''
    Greedy multiplicative coordinate search that fits ``function``'s
    parameters so that ``function(*parameters)`` matches ``target``.

    Each step visits every parameter in turn and keeps whichever of
    ``p*(1 - alpha)``, ``p`` or ``p*(1 + alpha)`` gives the lowest error.
    After a full pass, ``alpha`` is multiplied by ``(1 - decay)``. Because the
    moves are multiplicative, a parameter that is exactly 0 can never change.

    Parameters
    ----------
    function : callable
        Called as ``function(*parameters)``. The result must be subtractable
        from ``target`` (e.g. an ndarray of the same shape).
    target : array-like
        Values the function output is compared against.
    feature : object
        Unused; kept for interface compatibility.
    epochs : int
        Number of steps taken by ``fit``.
    alpha : float
        Initial relative step size.
    decay : float
        Fractional shrinkage of ``alpha`` after every step.
    verbose : int
        When equal to 1, ``fit`` prints each iteration number.
    '''
    import numpy
    import warnings
    # NOTE(review): this runs once, at class-definition time, and silences
    # every warning in the process - not only those raised by this class.
    warnings.filterwarnings("ignore")

    def __init__(self,function,target,feature,epochs=500,alpha=0.1,decay=0.1,verbose=0):
        self.target=target
        self.function=function
        # The string 'None' marks "not set yet" (kept for compatibility with
        # code that may inspect these attributes).
        self.parameters='None'
        self.guesses='None'
        self.epochs=epochs
        self.alpha=alpha
        self.decay=decay
        self.verbose=verbose

    @staticmethod
    def _is_unset(value):
        '''True only for the 'None' sentinel string.

        ``ndarray == 'None'`` is an element-wise comparison (and raises inside
        ``if`` on recent NumPy), so check the type first.
        '''
        return isinstance(value, str) and value == 'None'

    def _ensure_parameters(self):
        '''Seed ``parameters`` from ``guesses`` on first use.'''
        if self._is_unset(self.parameters):
            if self._is_unset(self.guesses):
                raise ValueError('No starting parameters: call load_guesses() first.')
            self.parameters=self.guesses

    def load_guesses(self,guesses):
        '''Set the starting parameter vector.'''
        self.guesses=guesses

    def calc_rmse(self,parameters):
        '''Square root of the summed squared residuals.

        This is not divided by the number of elements, so it is not a true
        RMSE; it is only used to rank candidates.
        '''
        return np.sum((self.function(*parameters)-self.target)**2)**0.5

    def gradient_descent_step(self):
        '''Do one pass over all parameters, then shrink ``alpha``.

        Returns the updated parameter array.
        '''
        self._ensure_parameters()
        for i in range(len(self.parameters)):
            # Candidate order (down, unchanged, up) matters: argmin keeps the
            # first of any tied errors.
            candidates=[]
            for factor in (1-self.alpha, 1, 1+self.alpha):
                trial=np.copy(self.parameters)
                trial[i]=trial[i]*factor
                candidates.append(trial)
            errors=np.array([self.calc_rmse(c) for c in candidates])
            self.parameters=np.array(candidates)[np.argmin(errors)]
        self.alpha=self.alpha*(1-self.decay)
        return self.parameters

    def fit(self):
        '''Run ``epochs`` steps of ``gradient_descent_step``.'''
        self._ensure_parameters()
        for iteration in range(self.epochs):
            self.gradient_descent_step()
            if self.verbose==1:
                print(iteration)

## Feature Class

In [None]:
class feature(object):
    '''
    Binary template of a fiducial marker together with a "target" point.

    The template is a 5 x 5 grid of ``size x size`` cells, so ``arr`` is
    ``5*size`` pixels square. The marker occupies cells in the central 3 x 3
    block; the outer ring of cells is always zero.

    ``rotate`` and ``flip`` treat index 0 of ``center``/``target`` as the
    horizontal (column) coordinate. For example, ``flip`` reverses columns and
    mirrors index 0 of ``target``.

    Parameters
    ----------
    size : int
        Edge length of one grid cell, in pixels.
    template : str, optional
        'corner', 'square' or 'gamma' (case-insensitive). ``None`` gives the
        corner template. An unrecognised name prints a message and also
        gives the corner template.
    rotation : int
        Counter-clockwise rotation in degrees; a multiple of 90.
    flip : bool
        Mirror the template after rotating.

    Attributes
    ----------
    arr : ndarray
        Template pixels (0.0 / 1.0).
    center : tuple
        ``(height/2, length/2)`` of ``arr``.
    target : tuple
        Point of interest on the marker.
    correction : tuple
        ``target - center``, component-wise.
    '''

    # Inner 3 x 3 cell patterns, rows listed top to bottom; 1 = marker cell.
    _PATTERNS_ = {
        'CORNER': [[0, 1, 1], [0, 1, 0], [0, 0, 0]],
        'SQUARE': [[0, 0, 0], [0, 1, 0], [0, 0, 0]],
        'GAMMA': [[0, 1, 1], [0, 1, 0], [0, 1, 0]],
    }

    def __init__(self,size,template=None,rotation=0,flip=False):
        self.arr = None
        self.size = size
        self.target = None
        # Build the default corner first so that arr/center/target/correction
        # always exist, even when no template is requested.
        self._corner_array_()
        if template is not None:
            builders = {
                'CORNER': self._corner_array_,
                'SQUARE': self._square_array_,
                'GAMMA': self._gamma_array_,
            }
            builder = builders.get(template.upper())
            if builder is None:
                print('Feature template not recognized, defaulting to corner.')
                builder = self._corner_array_
            builder()
        self.rotate(rotation)
        if flip:
            self.flip()
        return

    def _render_cells_(self, pattern):
        '''Draw a 3x3 cell pattern, with a one-cell zero border, into ``arr``.

        Also recomputes ``center``. Returns ``(height, length)`` of ``arr``.
        '''
        cells = np.zeros((5, 5))
        cells[1:4, 1:4] = pattern
        # Each grid cell becomes a size x size block of its value.
        self.arr = np.kron(cells, np.ones([self.size, self.size]))
        height, length = self.arr.shape
        self.center = (height/2, length/2)
        return height, length

    def _update_correction_(self):
        '''Recompute ``correction`` as ``target - center``.'''
        self.correction = (self.target[0]-self.center[0], self.target[1] - self.center[1])

    def _corner_array_(self):
        '''Corner marker (horizontal bar over a one-cell vertical stub).

        The target is at about (0.6*height, 0.4*length), i.e. (x, y) of
        roughly (3, 2) cells: the inner vertex of the corner.
        '''
        height, length = self._render_cells_(self._PATTERNS_['CORNER'])
        self.target = (height*0.5996, length*0.3996)
        self._update_correction_()
        return

    def _square_array_(self):
        '''Single central cell; the target is the array centre.'''
        self._render_cells_(self._PATTERNS_['SQUARE'])
        self.target = self.center
        self._update_correction_()
        return

    def _gamma_array_(self):
        '''Gamma-shaped marker; same target position as the corner template.'''
        height, length = self._render_cells_(self._PATTERNS_['GAMMA'])
        self.target = (height*0.5996, length*0.3996)
        self._update_correction_()
        return

    def _set_target_(self):
        '''Hook used by ``import_array``. Currently a no-op, so an imported
        array keeps the previous ``target`` and ``correction``.'''
        pass

    def import_array(self, array):
        '''Replace ``arr`` with a user-supplied array and recompute ``center``.'''
        self.arr = array
        height, length = self.arr.shape
        self.center = (height/2, length/2)
        self._set_target_()

    def _invert_(self):
        '''Swap 0 and 1 in ``arr``, for markers that appear dark.

        Values other than exactly 0 or 1 are left unchanged; the dtype is
        preserved.
        '''
        arr = self.arr
        swapped = np.where(arr == 1, 0, np.where(arr == 0, 1, arr))
        self.arr = swapped.astype(arr.dtype)

    def rotate(self, degrees):
        '''
        Rotate the feature array counterclockwise in multiples of 90 degrees.

        ``target`` is rotated about ``center`` by the same angle and
        ``correction`` is updated.

        Input:
        --------
        degrees : int
            desired rotation in degrees counterclockwise. Must be an integer multiple of 90 degrees!

        Raises
        ------
        ValueError
            If ``degrees`` is not a multiple of 90.
        '''
        if degrees % 90 != 0:
            raise ValueError('degrees must be an integer multiple of 90 degrees')
        # np.rot90 expects an integer number of quarter turns.
        self.arr = np.rot90(self.arr, int(degrees // 90))

        # Rotate the target point about the centre. The array's y axis points
        # down, so a counterclockwise turn on screen is a (360 - degrees)
        # rotation in these coordinates.
        degrees = 360-degrees
        sin = np.sin(np.deg2rad(degrees))
        cos = np.cos(np.deg2rad(degrees))
        x_t, y_t = self.target
        x_c, y_c = self.center
        x_t -= x_c
        y_t -= y_c
        x_n = x_t * cos - y_t * sin
        y_n = x_t * sin + y_t * cos
        self.target = (x_n + x_c, y_n + y_c)
        self._update_correction_()
        return

    def flip(self):
        '''
        Mirror the feature array left-right (reverse its columns) and mirror
        the target's x coordinate about the centre.
        '''
        self.arr = np.flip(self.arr, 1)
        x_t, y_t = self.target
        x_c, y_c = self.center
        self.target = (x_c + (x_c - x_t), y_t)
        self._update_correction_()
        return
        
                

In [None]:
class FeatureFinder(object):
    '''
    Score image windows against a ``feature`` template.

    Parameters
    ----------
    function : callable
        Window scoring function, called as ``function(window, feature,
        window_size)`` and returning an ndarray (e.g. ``gamma_correlation``).
    feature : feature
        Template whose ``arr`` is matched.
    window_size : int, optional
        Half-width of the window used by ``wiggle``. When omitted, ``wiggle``
        falls back to the module-level global ``window_size``, which was the
        original behaviour.
    '''
    def __init__(self, function=None, feature=None, window_size=None):
        self.function = function
        self.feature = feature
        self.window_size = window_size

    def _wiggle_window_size_(self):
        '''Half-width for ``wiggle``: the instance value, else the global.'''
        ws = getattr(self, 'window_size', None)
        if ws is None:
            # Backward-compatible fallback to the notebook-level global.
            ws = window_size
        return ws

    def wiggle(self, rotation, scaling, shift_x, shift_y):
        '''
        Render the template transformed by normalised parameters.

        The inputs are nominally in [0, 1] and are mapped to a rotation of
        -30..+30 degrees, a scale of 0.7..1.3, and integer shifts of -5..+5
        pixels (truncated toward zero). ``shift_x`` moves along columns (x)
        and ``shift_y`` along rows (y).

        Returns
        -------
        ndarray of shape (2*ws, 2*ws)
            The transformed template.
        '''
        ws = self._wiggle_window_size_()
        arr = cv2.resize(self.feature.arr, (ws*2, ws*2), interpolation=cv2.INTER_CUBIC)
        rotation = -30+rotation*60
        scaling = 0.7+scaling*0.6
        shift_x = int(-5+shift_x*10)
        shift_y = int(-5+shift_y*10)

        rotation_matrix = cv2.getRotationMatrix2D((ws, ws), rotation, scaling)
        altered_not_shifted = cv2.warpAffine(arr, rotation_matrix, (2*ws, 2*ws))
        return np.roll(np.roll(altered_not_shifted, shift_x, axis=1), shift_y, axis=0)

    def sliding_function(self,array,function,feature,window_size=32,step=1):
        '''
        Evaluate ``function`` on every ``2*window_size`` square window.

        Window centres run over ``range(window_size, dim - window_size, step)``
        on both axes, so windows never leave the image.

        Returns
        -------
        ndarray of shape (n_rows, n_cols, k)
            Entry ``[r, c]`` holds the flattened result for the window centred
            at row ``window_size + r*step``, column ``window_size + c*step``.
            ``k`` is the length of the result for the top-left window. A
            window whose result is not an ndarray is stored as 0.
        '''
        rows = range(window_size, array.shape[0]-window_size, step)
        cols = range(window_size, array.shape[1]-window_size, step)

        def window(r, c):
            return array[r-window_size:r+window_size, c-window_size:c+window_size]

        # Probe the top-left window to learn the length of each result.
        probe = np.ndarray.flatten(function(window(window_size, window_size), feature, window_size))
        transformed = np.zeros([len(rows), len(cols), len(probe)], dtype='float64')
        for r_idx, r in enumerate(rows):
            for c_idx, c in enumerate(cols):
                result = function(window(r, c), feature, window_size)
                try:
                    transformed[r_idx, c_idx] = np.ndarray.flatten(result)
                except TypeError:
                    transformed[r_idx, c_idx] = 0
        return transformed

    def search_image(self, image, f_count, feature=None):
        '''Placeholder for a full-image search; not implemented yet.

        Raises
        ------
        NotImplementedError
            Always.
        '''
        if feature is None:
            feature = self.feature
        raise NotImplementedError('FeatureFinder.search_image is not implemented yet')

    def generate_heatmap(self, image, step, window_size):
        '''Real part of the first output channel of ``sliding_function``
        over ``image``, using this finder's ``function`` and ``feature``.'''
        DD = self.sliding_function(image, self.function, self.feature, window_size, step)
        return np.real(DD[:, :, 0])

# Main Body

### Initial Settings and Image Read-in

Load the MALDI training images from a pickled binary file.

In [None]:
data_dir = r'/home/tyler/Documents/Work/Ovchinnikova/ionmaps/'
images_f_name = 'slide2_experimental_positive_ionMaps.pik'
images_f_name = os.path.join(data_dir, images_f_name)
# NOTE: unpickling can execute arbitrary code - only load trusted files.
# ``with`` closes the handle even if unpickling fails.
with open(images_f_name, 'rb') as images_file:
    maldi_image_list = list(pickle.load(images_file))

In [None]:
# Preview: stack every ion image semi-transparently on a single axis.
image_count = len(maldi_image_list)
i_x, i_y = maldi_image_list[0].shape
# Per-layer opacity is 1/N, floored at 0.05 so large stacks stay visible.
a_factor = max(1/image_count, 0.05)
fig, ax = plt.subplots(1)
plt.title('Sum image')
for image in maldi_image_list:
    ax.imshow(image, cmap='Greys_r', alpha=a_factor)
print('Imported %i images' % image_count)

### Feature template generation

In [None]:
# --- Physical geometry of the fiducial marker ---
resolution = 50        # Pixel-to-pixel distance, in micrometers
monomer_size = 250     # Length of feature monomers, in micrometers
length = monomer_size*2
# Set to True when the fiducial markers appear as dark (rather than bright)
# features in the ion images.
inversion_mode = False
template = 'corner'
buffer = 0
buffered_length = length + length * buffer
length_px = buffered_length*(1/resolution)
# Half-width of the sliding analysis window, in pixels.
window_size = int(length_px/2*5/2)
# NOTE(review): the template is built with monomer_size (micrometers) as its
# cell size in pixels. It is resized to the window before use, so only its
# proportions matter - confirm this is intended.
feat = feature(monomer_size, template=template, rotation=180, flip=False)
if inversion_mode:
    feat._invert_()
arr = feat.arr
f_size = arr.shape

In [None]:
fig, ax = plt.subplots(1)
ax.imshow(arr)
plt.title('Scaled Feature Template')
# Red marks the array centre, green the target point (plotted as (x, y)).
ax.scatter(*feat.center[:2], color='red')
ax.scatter(*feat.target[:2], color='green')
print('Feature center = {:f},{:f}'.format(feat.center[0], feat.center[1]))
print('Feature target = {:f},{:f}'.format(feat.target[0], feat.target[1]))
print('Correction vector = {:f},{:f}'.format(feat.correction[0], feat.correction[1]))

______________
Attempt to locate the fiducial markers in the imported ion or PCA maps. Each map is searched individually to build a list of candidate points per image. The two most common values (modes) of x and of y across all candidates are then used to infer the rough anchor points.

In [None]:
# Correlate the template against every image and time each pass.
ff = FeatureFinder(gamma_correlation, feat)
step = 3
times = []
heatmaps = []
print('Image :',' processing time :',' average time :',' total_time')
for idx, image in enumerate(maldi_image_list):
    started = time.time()
    heatmap = ff.generate_heatmap(image, step, window_size)
    dt = time.time() - started
    times.append(dt)
    total_time = sum(times)
    average_time = total_time / len(times)
    heatmaps.append(heatmap)
    print(idx + 1, ':', dt, ':', average_time, ':', total_time)

In [None]:
from skimage.feature import peak_local_max
# Minimum separation between peaks, in heatmap pixels.
min_distance = int(1/(step/10)+1)
rough_feature_top3 = []
for heatmap in heatmaps:
    # peak_local_max returns (row, col); flip each pair to (x, y) = (col, row).
    local_maxima = np.flip(peak_local_max(heatmap, min_distance=min_distance), 1)
    if len(local_maxima) < 3:
        print('Skipping heatmap: only %i local maxima found (need 3)' % len(local_maxima))
        continue
    # Score each peak by its heatmap value (heatmap is indexed [row, col]).
    maxima_scores = heatmap[local_maxima[:, 1], local_maxima[:, 0]]
    # Stable descending order: tied scores keep peak_local_max order, and each
    # peak is selected at most once.
    order = np.argsort(-maxima_scores, kind='stable')[:3]
    top_3_coords = local_maxima[order].astype(int)
    rough_feature_top3.append(top_3_coords)

In [1]:
def _mode(values):
    '''Most frequent value in ``values``; ties go to the smallest value
    (the same tie rule as ``scipy.stats.mode``). Always returns a scalar.'''
    if len(values) == 0:
        raise ValueError('cannot take the mode of an empty list of coordinates')
    uniq, counts = np.unique(values, return_counts=True)
    return uniq[np.argmax(counts)]


# Pool the (x, y) peak coordinates from every heatmap.
full_pointlist = [tuple(point) for top3 in rough_feature_top3 for point in top3]
xs = [p[0] for p in full_pointlist]
ys = [p[1] for p in full_pointlist]
# The most common x and y values give the first anchor.
mode_x1 = _mode(xs)
mode_y1 = _mode(ys)
xs = [x for x in xs if x != mode_x1]
ys = [y for y in ys if y != mode_y1]
mode_x2 = _mode(xs)
mode_y2 = _mode(ys)
# A second mode within 3 px of the first presumably belongs to the same
# marker; discard it and take the next most common value instead.
if abs(mode_x1 - mode_x2) < 3:
    xs = [x for x in xs if x != mode_x2]
    mode_x2 = _mode(xs)
if abs(mode_y1 - mode_y2) < 3:
    ys = [y for y in ys if y != mode_y2]
    mode_y2 = _mode(ys)
rough_anchors = [(mode_x1, mode_y1), (mode_x1, mode_y2), (mode_x2, mode_y1)]
# Map heatmap indices back to image pixel coordinates.
scaled_rough_anchors = [(x*step+window_size, y*step+window_size) for (x, y) in rough_anchors]

NameError: name 'rough_feature_top3' is not defined

In [None]:
fig, ax = plt.subplots(1)
plt.title('Rough Anchor Point Guesses')
for heatmap in heatmaps:
    ax.imshow(heatmap, alpha=0.1, cmap='Greys_r')
# Blue: every pooled top-3 peak; red: the inferred rough anchors.
for points, colour in ((full_pointlist, 'blue'), (rough_anchors, 'red')):
    for x, y in points:
        ax.scatter(x, y, color=colour)

In [None]:
# Show the rough anchors (image pixel coordinates) on one sample ion image.
print(scaled_rough_anchors)
sample_image = maldi_image_list[3]
plt.imshow(sample_image, cmap='hot')
for anchor_x, anchor_y in scaled_rough_anchors:
    plt.scatter(anchor_x, anchor_y, color='blue')

### Fine Fitting

In [None]:
# Fine fitting: for every image and every rough anchor, fit the rotation,
# scale and shift of the template to the local window, then map the
# template's target point into image coordinates.
anchors_list = []
# 3x3 smoothing kernel. It sums to 3 rather than 1, but each window is
# min-max normalised right after filtering, so the overall scale is irrelevant.
kernel = np.ones((3,3),np.float32)/3
arr_n = cv2.resize(feat.arr, (2*window_size, 2*window_size))
rescale_factor = arr_n.shape[0]/feat.arr.shape[0]
# Offset of the target from the template centre, in window pixels, as (x, y).
feature_correction = [coord*rescale_factor for coord in feat.correction]


def _fit_window(check):
    '''Fit ff.wiggle to a normalised window.

    Returns (normalised, physical) parameter arrays. Physical parameters are
    (rotation_deg, scale, shift_x, shift_y), using wiggle's linear mapping
    (shifts here are not yet truncated).
    '''
    find_marker = function_ls(function=ff.wiggle, target=check, feature=feat,
                              epochs=200, alpha=0.5, decay=0.05, verbose=0)
    find_marker.load_guesses(np.array([0.5,0.5,0.5,0.5]))
    find_marker.fit()
    rotation, scaling, shift_x, shift_y = find_marker.parameters
    recalc_parameters = np.array([rotation*60-30,
                                  scaling*0.6+0.7,
                                  shift_x*10-5,
                                  shift_y*10-5])
    return find_marker.parameters, recalc_parameters


def _anchor_from_fit(window_center, physical):
    '''Map the template target into image (x, y) for one fitted window.'''
    rotation_degrees, S, shift_x, shift_y = physical
    # Linear part of the affine map wiggle applies; it already includes S.
    A = cv2.getRotationMatrix2D((window_size, window_size), rotation_degrees, S)[:, :2]
    # wiggle rolls the template by int(...) pixels, so truncate the same way.
    T = np.array([int(shift_x), int(shift_y)])
    fit_center = np.asarray(window_center, dtype=float) + T
    # warpAffine sends a template point c + d to c + A @ d (column vector).
    correction_vector = A @ np.asarray(feature_correction, dtype=float)
    return fit_center + correction_vector


for image in maldi_image_list:
    # [anchor, parameter, 0 = normalised / 1 = physical]
    top_3_fitted_parameters = np.empty((3,4,2))
    window_coords = []
    for idx, (x, y) in enumerate(scaled_rough_anchors):
        window_coords.append([x, y])
        # The image is indexed [row, col] = [y, x].
        check = image[y-window_size:y+window_size, x-window_size:x+window_size]
        check = cv2.filter2D(check, -1, kernel)
        check = check - np.min(check)
        peak = np.max(check)
        if peak > 0:  # avoid 0/0 -> NaN on a flat window
            check = check / peak
        normalised, physical = _fit_window(check)
        top_3_fitted_parameters[idx, :, 0] = normalised
        top_3_fitted_parameters[idx, :, 1] = physical
    anchors = np.array([_anchor_from_fit(window_coords[i], top_3_fitted_parameters[i, :, 1])
                        for i in range(3)])
    anchors_list.append(anchors)
    plt.figure()
    plt.imshow(image)
    

In [None]:
# Per-anchor "mode" of the fitted coordinates across all images.
# NOTE(review): the anchors are floats, so exact repeats are rare and this
# usually falls back to the smallest value. The following cell reassigns
# fitted_anchors with a sigma-clipped mean.
fitted_anchors = []
for i in range(3):
    picked = []
    for axis in (0, 1):
        values = [anchor_set[i][axis] for anchor_set in anchors_list]
        unique, counts = np.unique(values, return_counts=True)
        picked.append(unique[counts == counts.max()][0])
    fitted_anchor = tuple(picked)
    fitted_anchors.append(fitted_anchor)
    print(fitted_anchor)

In [None]:
# Robust per-anchor estimate: average the fitted positions over all images
# after discarding points more than one standard deviation from the mean in
# either x or y.
fitted_anchors = []
for i in range(3):
    xs = np.array([anchor_set[i][0] for anchor_set in anchors_list])
    ys = np.array([anchor_set[i][1] for anchor_set in anchors_list])
    # Keep points strictly within 1 sigma on BOTH axes.
    keep = (np.abs(xs - np.mean(xs)) < np.std(xs)) & (np.abs(ys - np.mean(ys)) < np.std(ys))
    if not keep.any():
        # For example, zero spread (all images agree) makes the strict test
        # reject everything; use the unclipped mean rather than NaN.
        keep = np.ones_like(keep)
    fitted_anchor = (np.mean(xs[keep]), np.mean(ys[keep]))
    fitted_anchors.append(fitted_anchor)
    print(fitted_anchor)
fitted_anchors = np.array(fitted_anchors, dtype=np.float32)

In [None]:
# Overlay every ion image and mark the fine-fitted anchors in red.
a_factor = max(1/image_count, 0.05)
fig, ax = plt.subplots(1)
plt.title('Sum image')
for image in maldi_image_list:
    ax.imshow(image, cmap='Greys_r', alpha=a_factor)
for x, y in fitted_anchors:
    plt.scatter(x, y, color='red')
plt.title('Fine fitting results')

In [None]:
# SIMS NMF maps (pickled)
pickled_sims_image_dir = r'/home/tyler/Documents/Work/Ovchinnikova/maldi-sims ion maps/sims/'
sims_pickled_fname = os.path.join(pickled_sims_image_dir, r'sims_nmf_maps.pik')
# MALDI PCA maps (pickled)
maldi_ion_image_dir = r'/home/tyler/Documents/Work/Ovchinnikova/maldi-sims ion maps/maldi/'
maldi_pickled_fname = os.path.join(maldi_ion_image_dir, r'maldi_pca_map_dump.pik')
# Destination for the co-registered overlays
output_dir = r'/home/tyler/Documents/Work/Ovchinnikova/maldi-sims ion maps/coregistered'

In [None]:
# Reference positions of the three fiducial anchors in the SIMS image.
# Presumably (x, y) pixel coordinates listed in the same order as
# fitted_anchors - confirm.
sims_anchors = np.array([[212, 325], [212, 887], [1425, 325]], dtype=np.float32)
with open(sims_pickled_fname, 'rb') as f:
    pickled_maps = pickle.load(f)
pickled_shape = pickled_maps[0].shape

In [None]:
# Affine transform taking the MALDI anchors (fitted_anchors) onto the SIMS
# anchors; both inputs are float32 arrays of three points.
m = cv2.getAffineTransform(fitted_anchors, sims_anchors)

In [None]:
# Display the affine transform matrix.
m

In [None]:
with open(maldi_pickled_fname, 'rb') as f:
    maldi_ion_maps = pickle.load(f)
# Warp the first 20 MALDI maps into the SIMS frame. Each map is first
# rotated with np.rot90(k=3) and mirrored left-right; warpAffine's output
# size is given as (width, height).
target_size = (pickled_shape[1], pickled_shape[0])
registered_ion_maps = []
for i in range(20):
    oriented = np.flip(np.rot90(maldi_ion_maps[:, :, i], 3), 1)
    registered_ion_maps.append(cv2.warpAffine(oriented, m, target_size))
# An earlier variant handled a dict-based pickle (maps keyed by mass under
# maldi_ion_maps['mean']) and warped each entry the same way.

In [None]:
# Display how many maps were registered.
len(registered_ion_maps)

In [None]:
# Compare one registered MALDI map with its re-oriented, un-warped source.
maldi_pullmap = registered_ion_maps[1]
for shown in (maldi_pullmap, np.flip(np.rot90(maldi_ion_maps[:, :, 1], 3), 1)):
    plt.figure()
    plt.imshow(shown)

In [None]:
# SIMS map at index 4 (labelled "sims nmf 5" in the saved overlay).
sims_pullmap = pickled_maps[4]
plt.imshow(sims_pullmap)

In [None]:

plt.figure()
# MALDI in red, SIMS in green, drawn on the same axes.
plt.imshow(maldi_pullmap, cmap='Reds', alpha=0.8)
plt.imshow(sims_pullmap, cmap='Greens', alpha=0.4)
# NOTE(review): maldi_pullmap is registered_ion_maps[1], while the label says
# "maldi pca 3" (the SIMS label "nmf 5" matches index 4 one-based) - confirm.
figname = 'ion_map_%s.png'%('maldi pca 3, sims nmf 5')
# savefig does not create missing directories.
os.makedirs(output_dir, exist_ok=True)
figname = os.path.join(output_dir, figname)
plt.savefig(figname)

In [None]:
# Overlay the first registered MALDI map (red) on the first SIMS map (green).
# NOTE(review): this cell referenced ``mapped_maldi_images``, which is not
# defined anywhere in the visible notebook. ``registered_ion_maps`` (the
# warped MALDI maps) is assumed to be the intended source - confirm.
plt.figure()
for image in registered_ion_maps[:1]:
    plt.imshow(image, alpha=0.9, cmap='Reds')
for image in pickled_maps[:1]:
    plt.imshow(image, alpha=0.7, cmap='Greens')

In [None]:
# Render figures inline in the notebook.
%matplotlib inline

In [None]:
# Render figures in separate Qt windows instead.
%matplotlib qt

In [None]:
# Inspect the container type of the loaded MALDI maps.
type(maldi_ion_maps)

In [None]:
# NOTE(review): this treats maldi_ion_maps as a dict with a 'mean' entry,
# whereas the registration cell indexes it as an array ([:, :, i]). It
# presumably dates from an older pickle format - confirm before running.
list(maldi_ion_maps['mean'].keys())[28]