In [16]:
import os
import pickle

import pandas as pd
import numpy as np

from skimage import transform
from skimage.transform import downscale_local_mean

import matplotlib.pyplot as plt

import napari

from pystackreg import StackReg
import pystackreg

In [3]:
# --- configuration: input / output locations ---

# folder that holds the pickled data frame
dir_path = r'Z:\Wayne\20210716_sen_sig_IF'
# file name of the pickled data frame inside dir_path; the output file name
# for the transformations is derived from it
df_name = 'df_A3.pkl'

# folder whose files are all read as label images
labels_dir = r'D:\toDel\results'

# the transformations are saved next to the input data frame
save_dir = dir_path

In [4]:
# load the pickled data frame from dir_path
df_path = os.path.join(dir_path, df_name)
myData = pd.read_pickle(df_path)

In [12]:
# read every file found in labels_dir, in sorted file-name order
labels_list = sorted(os.listdir(labels_dir))

labels_im_list = [
    plt.imread(os.path.join(labels_dir, label_file_name))
    for label_file_name in labels_list
]

# combine the images into a single array (first axis follows labels_list,
# assuming all images share the same shape)
labels_im = np.array(labels_im_list)

In [13]:
# Sanity check before registration: compare the number of label images that
# were loaded with the number of distinct values in myData.alignRound.
# NOTE(review): presumably there is one label image per alignment round, so the
# two numbers are expected to match — confirm.
print(f'Number of label images {labels_im.shape[0]}')
print(f'Number of unique rounds {len(set(myData.alignRound))}')

Number of label images 3
Number of unique rounds 3


## Registration

In [14]:
def find_transformation(labels_im,downscale_factor = 1):
    """Estimate per-frame rigid-body transforms that align a stack of label images.

    The stack is binarised (``labels_im > 0``), downscaled in-plane by
    ``downscale_factor`` with ``downscale_local_mean`` and registered against
    its first frame with pystackreg (``StackReg.RIGID_BODY``). Each resulting
    matrix is then converted to a full-resolution transform by multiplying its
    translation by ``downscale_factor`` and keeping its rotation.

    Parameters
    ----------
    labels_im : array
        Stack of label images; axis 0 indexes the frames to register. The
        downscale factors ``(1, f, f)`` assume a 3-D stack.
    downscale_factor : int, optional
        In-plane downscaling factor applied before registration. With the
        default of 1 the full-resolution transforms equal the registered ones.

    Returns
    -------
    tmat : list of skimage.transform.EuclideanTransform
        One transform per registered frame, with the translation scaled by
        ``downscale_factor``.
    tmat_small
        Transformation matrices exactly as returned by
        ``StackReg.register_stack`` for the downscaled stack.
    """

    # binarise, then shrink in-plane only (the frame axis keeps factor 1)
    labels_small = labels_im > 0
    labels_small = downscale_local_mean(labels_small, (1, downscale_factor, downscale_factor))

    # rigid-body registration of every frame against the first one
    sr = StackReg(StackReg.RIGID_BODY)
    tmat_small = sr.register_stack(labels_small, axis=0, reference='first', verbose=True)

    # Rescale the transformations. This is done unconditionally so that `tmat`
    # is always defined: previously it was only assigned when
    # downscale_factor > 1, and the default factor raised UnboundLocalError.
    tmat = []
    for transform_matrix in tmat_small:

        eu_transform_small = transform.EuclideanTransform(transform_matrix)

        # only the translation depends on the pixel size; the rotation does not
        eu_transform = transform.EuclideanTransform(translation = eu_transform_small.translation * downscale_factor,
                                                    rotation = eu_transform_small.rotation)

        tmat.append(eu_transform)

    return tmat,tmat_small

def apply_transforms_set(tmat,movie):
    """Warp each frame of ``movie`` with the matching entry of ``tmat``.

    Parameters
    ----------
    tmat : sequence
        One transformation per frame. Each entry may be a transformation
        matrix (wrapped in ``skimage.transform.EuclideanTransform``) or an
        already constructed ``EuclideanTransform`` object, such as the
        full-resolution transforms returned by ``find_transformation``.
    movie : array
        Stack of frames indexed as ``movie[index, :, :]``; it must contain at
        least as many frames as ``tmat`` has entries.

    Returns
    -------
    numpy.ndarray
        The warped frames stacked along axis 0, in the order of ``tmat``
        (an empty array when ``tmat`` is empty).
    """

    res = []

    for index, frame_transform in enumerate(tmat):

        # accept ready-made transform objects as well as raw matrices
        if isinstance(frame_transform, transform.EuclideanTransform):
            eu_transform = frame_transform
        else:
            eu_transform = transform.EuclideanTransform(frame_transform)

        # the transform is handed to warp as its coordinate map; the output
        # keeps the shape of the input frame
        temp = transform.warp(movie[index,:,:],eu_transform,output_shape=movie[index].shape)

        res.append(temp)

    return np.array(res)

In [17]:
# in-plane downscaling used for the registration (and for the visual check)
downscale_factor = 4

# tmat: transforms with translations rescaled by downscale_factor;
# tmat_small: matrices found on the downscaled stack
tmat, tmat_small = find_transformation(labels_im, downscale_factor=downscale_factor)

100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:48<00:00, 24.48s/it]


In [18]:
# test transformations
# Visual check: apply the transformations found on the downscaled stack to a
# downscaled copy of the label images and overlay the result in napari.

# downscale with the same factor that was used for the registration
labels_small = downscale_local_mean(labels_im,(1,downscale_factor,downscale_factor))
# tmat_small holds the matrices estimated at this downscaled resolution
labels_small_aligned = apply_transforms_set(tmat_small,labels_small)

# Overlay with additive blending: first frame (gray), unaligned stack (red),
# aligned stack (green).
# NOTE(review): the recorded output names the layer after the argument
# variable, so napari appears to derive layer names from these expressions —
# keep the calls in this form if the layer names matter.
viewer = napari.Viewer()
viewer.add_image(labels_small[0],blending ='additive',colormap='gray')
viewer.add_image(labels_small,blending ='additive',colormap='red')
viewer.add_image(labels_small_aligned,blending ='additive',colormap='green')

  zoom = np.min(canvas_size / scale)


<Image layer 'labels_small_aligned' at 0x2237f80f0c8>

In [19]:
# save transformations
# The output name is derived from the data frame file name by replacing 'df'
# with 'tmat'.
save_file_path = os.path.join(save_dir,df_name.replace('df','tmat'))

# Use a context manager so the file handle is flushed and closed even if
# pickling fails (the handle was previously never closed).
with open(save_file_path, "wb") as tmat_file:
    pickle.dump(tmat, tmat_file)