# PX analysis 
1. Input experiment parameters
2. Generate a 10x scale down multi-channel image for quick QC
3. Register ch-ch using mutial information GPU
4. Register cycle-cycle using phase-correlation
5. Stitch images together with phase-correlation
6. Generate OME-Zarr image for Napari (using a seperate notebook, need to change env)

Note:
1. we use 3N code, ch4 is not processed
2. we assume Cyc1 and Ch1 is the brightest, and we align all data to Cy1, Ch1

## INPUT EXPERIMENT PARAMETERS

In [5]:
Tile_SZ = [3200, 5440]              # assuming no changes, global parameters
Col, Row, Z = 3, 1, 6               # pick a well and a Z
Zs     = [1,3,5,7]                  # for quick QC only
Tiles  = [1, 2, 3, 4]
Swaths = [1, 2, 3, 4]
Cycles = [1, 2, 3, 4, 5]
Channels= [1, 2, 3]                 # 3N
Cyc_0, Ch_0 = 0, 0                  # aligned all data to cyc-0, ch-0

XTALK_CORRECTION = False            # try xtalk

Data_Path = "../../data/"                       # raw data 
TMP_Path = f"../col_{Col}_row_{Row}/tmp/"       # intermediate folder
Result_Path = f"../col_{Col}_row_{Row}/result/"  # result folderd

Barcodes_Path = "./barcode.csv"                 # barcodes info
RootF = "Col_ccc_Swath_sss_Tile_00fff_Cyc_0ttt_row_rrr_z_zzz_seq_1.tiff"        #file format

import os
 
if not( os.path.exists(TMP_Path)):
    os.makedirs(TMP_Path)
if not( os.path.exists(Result_Path)):
    os.mkdir(Result_Path)

## QC

In [20]:
# creat a montage file to quickly check the quality of the sequencing data
# then create a files of all Z to check Z
import numpy as np  
import matplotlib.pyplot as plt 
import glob, os, skimage, tifffile

scale = 10       # reduced scale
tile_d_sz = np.ceil( np.array(Tile_SZ) / scale ).astype(int)

imgs = np.zeros( (len(Tiles)*tile_d_sz[0], len(Swaths)*tile_d_sz[1], len(Channels), 
                  len(Cycles), len(Zs)), dtype=np.uint16 )  #YXCTZ
                
imgZ= np.zeros((Tile_SZ[0], Tile_SZ[1], len(Channels), len(Cycles),len(Zs)), dtype=np.uint16)
sel_swath, sel_tile = 2, 2

img_f = f"{Data_Path}{RootF}"

cnt, cnt_max = 0,  10000
for i,cycle in enumerate(Cycles):
    for j,swath in enumerate(Swaths):
        for k,tile in enumerate(Tiles):
            for l,z in enumerate(Zs):
                f = img_f.replace( 'ccc', f'{Col}')
                f = f.replace('rrr', f'{Row}')
                f = f.replace('zzz', f'{z}')
                
                if cycle<10:
                    f = f.replace('ttt', f'0{cycle}')
                else:
                    f = f.replace('ttt', f'{cycle}')
                    
                f = f.replace('sss', f'{swath}')
                f = f.replace('fff', f'{tile}')
                
                img = skimage.io.imread(f)                  #YXC format
                img = img[:,:,:len(Channels)]
                if (swath == sel_swath) and (sel_tile==tile):
                    imgZ[:,:,:, i, l] = img                 #YXCTZ
                
                img_d = img[0:-1:scale, 0:-1:scale, :]
                
                grid_x = ( swath - 1 ) * tile_d_sz[1]
                
                if swath % 2:
                    y_coordinate = len(Tiles)+1 - tile
                    img_d = np.flip(img_d, axis=0)
                else:
                    y_coordinate = tile
                grid_y = (y_coordinate-1) * tile_d_sz[0]
                
                imgs[ grid_y:grid_y+tile_d_sz[0], grid_x:grid_x+tile_d_sz[1], :, 
                     i, l] = img_d   #YXCTZ
            
print(f"the shampe of img is: {imgs.shape}, Zs={Zs}")

# dimension_order = "TZCYX" from YXCTZ
dimension_order = 'TZCYX'
img_path = f'{TMP_Path}qc_col_{Col}_row_{Row}.tiff'

img_fiji = np.transpose( imgs, (3, 4, 2, 0, 1))       #TZCYX
tifffile.imwrite( img_path, img_fiji.astype('uint16'), shape = img_fiji.shape, 
                 imagej=True, metadata={'axes':dimension_order,} )

img_path = f'{TMP_Path}qc_col_{Col}_row_{Row}_swath_{sel_swath}_tile_{sel_tile}_Z.tiff'
img_fiji = np.transpose( imgZ, (3, 4, 2, 0, 1))       #TZCYX
tifffile.imwrite( img_path, img_fiji.astype('uint16'), shape = img_fiji.shape, 
                 imagej=True, metadata={'axes':dimension_order,} )



the shampe of img is: (1280, 2176, 3, 5, 4), Zs=[1, 3, 5, 7]


## CH REGISTRATION WITH MI

In [21]:
import numpy as np  
import pandas as pd
import glob, os, skimage, tifffile, sys 
sys.path.append( r"./globalign-main/")
import globalign, nd2cat 

def align(ref, tar):
    h, w, Q = ref.shape[0], ref.shape[1], 8
    ref2 = skimage.exposure.equalize_adapthist( ref,clip_limit=0.05 )
    tar2 = skimage.exposure.equalize_adapthist( tar,clip_limit=0.05 )

    ref8 = np.round( ref2/ref2.max()*255 ).astype(np.uint8)
    tar8 = np.round( tar2/tar2.max()*255 ).astype(np.uint8)
    
    ref_q = nd2cat.image2cat_kmeans( ref8,Q )
    tar_q = nd2cat.image2cat_kmeans( tar8,Q )
    M = np.ones( (h,w), dtype='bool' )
    
    overlap = 0.5
    grid_angles = 32
    max_angle   = .2
    refinement_param = {'n': 32, 'max_angle': max_angle}
    on_gpu = True 
    max_translation = 30
    
    param = globalign.align_rigid_and_refine(ref_q, tar_q, M, M, Q, Q, grid_angles, max_angle, 
                                            refinement_param=refinement_param, overlap=overlap, 
                                            enable_partial_overlap=True, normalize_mi=False, 
                                            on_gpu=on_gpu, save_maps=False)
    # print('Mutual information: ', param[0][0])
    # print('Rotation angle: ', param[0][1])
    # print('Translation: ', param[0][2:4])
    # print('Center of rotation: ', param[0][4:6])
    
    if np.max( np.abs(param[0][2:4]) ) < max_translation:
        img_recover = globalign.warp_image_rigid(ref,
                                                tar,
                                                param[0],
                                                mode='nearest',
                                            bg_value=0)
    else:
        img_recover = tar 
    return img_recover

def align_an_img(imgs):
    nx, ny = 8, 6*4
    pad    = 32
    grid_x = np.linspace(0, imgs.shape[1], nx).astype(int)
    grid_y = np.linspace(0, imgs.shape[0], ny).astype(int)

    imgs_aligned = imgs.copy()

    for i,x in enumerate(grid_x[:-1]):
        for j,y in enumerate(grid_y[:-1]):
            x0 = x - pad 
            x1 = grid_x[i+1]+pad
            y0 = y - pad 
            y1 = grid_y[j+1]+pad 
            
            if x0<=0: 
                x0, x00 = 0, 0
            else:
                x00 = pad 
                
            if x1>=imgs.shape[1]:
                x1 = imgs.shape[1]
                x11= x1-x0
            else:
                x11 = x1-x0-pad 
                
            if y0<=0:
                y0, y00 = 0, 0
            else:
                y00 = pad
            if y1>=imgs.shape[0]:
                y1 = imgs.shape[0]
                y11= y1-y0  
            else: 
                y11= y1-y0-pad
                
            ref = imgs[y0:y1, x0:x1, 0]
            for k in np.arange(1,3):
                tar = imgs[y0:y1, x0:x1, k]
                tar_aligned = align( ref, tar )
                imgs_aligned[y:grid_y[j+1], x:grid_x[i+1], k] = tar_aligned[y00:y11, x00:x11]
    return imgs_aligned 
      

In [23]:
# doing ch-ch alignment 
#create a big tiff file
dimension_order = 'CYX'
cnt, cnt_max = 0, 100000

img_f = f"{Data_Path}{RootF}"
try:
    for cyc in Cycles:        
        for swath in Swaths:
            img_swath = np.zeros( (len(Tiles)*Tile_SZ[0], Tile_SZ[1], len(Channels)), 
                                 dtype=np.uint16) # 'YXC'
            for tile in Tiles:
                f = img_f.replace( 'ccc', f'{Col}')
                f = f.replace('rrr', f'{Row}')
                f = f.replace('zzz', f'{Z}')
            
                if cyc<10:
                    f = f.replace('ttt', f'0{cyc}')
                else:
                    f = f.replace('ttt', f'{cyc}')
                
                f = f.replace('sss', f'{swath}')
                f = f.replace('fff', f'{tile}')
                img = skimage.io.imread(f)          #YXC
                img = np.squeeze( img[:, :, :len(Channels)] )
                
                if swath % 2:                       # already flipped here
                    img = np.flip(img, axis=0)
                    y_coordinate = len(Tiles)+1 - tile
                else:
                    y_coordinate = tile
                    
                grid_y = (y_coordinate-1) * Tile_SZ[0]
                img_swath[ grid_y:grid_y+Tile_SZ[0], :,:] = img
                
            img_fiji = align_an_img(img_swath)
            img_fiji = img_fiji.transpose(2,0,1)    # 'CYX'
                
            if cyc<10:
                of = f'{TMP_Path}swath_{swath}_cyc_00{cyc}_z_{Z}_chcorrected.tiff'
            else:
                of = f'{TMP_Path}swath_{swath}_cyc_0{cyc}_z_{Z}_chcorrected.tiff'
            tifffile.imwrite(of, img_fiji.astype('uint16'), 
                            shape=img_fiji.shape, imagej=True, 
                            metadata={'axes':dimension_order,})     
        cnt += 1           
        if cnt>cnt_max:
            raise StopIteration
except StopIteration:
    pass

## CYCLE REGISTRATION WITH PHASE

In [27]:
import scipy.ndimage, skimage.exposure 
from skimage.registration import phase_cross_correlation 
from scipy.ndimage import fourier_shift 

def cal_shift(ref, tar):
    ref = skimage.exposure.equalize_adapthist( ref, clip_limit=0.05)
    tar = skimage.exposure.equalize_adapthist( tar, clip_limit=0.05)
    shift, error, diffphase = phase_cross_correlation(ref, tar)
    # tar_aligned = fourier_shift( np.fft.fftn(tar), shift)
    # tar_aligned = np.fft.ifftn( tar_aligned )
    # tar_aligned = np.round( tar_aligned).astype( np.uint16 )
    # print(shift)
    return shift

def align(imgs):    # imgs: TYXC
    imgs_aligned = imgs.copy() 
    imgr = np.squeeze( np.max(imgs, axis=3))    # TYX
    ref  = np.squeeze( imgr[0, :,:])            # YX
    ncyc = imgr.shape[0]
    
    nx, ny = 4, 3*len(Swaths)
    pad    = 32
    grid_x = np.linspace(0, imgs.shape[2], nx).astype(int)
    grid_y = np.linspace(0, imgs.shape[1], ny).astype(int)
    max_shift = 40
    
    for cyc in np.arange(1, ncyc):
        tar = np.squeeze( imgr[cyc,:,:] )
        shift_global = cal_shift( ref, tar )   # global shift
        
        for i,x in enumerate(grid_x[:-1]):
            for j,y in enumerate(grid_y[:-1]):
                x0 = x - pad 
                x1 = grid_x[i+1]+pad
                y0 = y - pad 
                y1 = grid_y[j+1]+pad 
            
                if x0<=0: 
                    x0, x00 = 0, 0
                else:
                    x00 = pad 
                
                if x1>=imgs.shape[2]:
                    x1 = imgs.shape[2]
                    x11= x1-x0
                else:
                    x11 = x1-x0-pad 
                
                if y0<=0:
                    y0, y00 = 0, 0
                else:
                    y00 = pad
                if y1>=imgs.shape[1]:
                    y1 = imgs.shape[1]
                    y11= y1-y0  
                else: 
                    y11= y1-y0-pad
                
                ref_roi = ref[y0:y1, x0:x1]
                tar_roi = tar[y0:y1, x0:x1]
                shift_local = cal_shift( ref_roi, tar_roi )
                if np.max( np.abs(shift_local - shift_global) ) <= max_shift:
                    shift = shift_local
                else:
                    shift = shift_global
                
                for ch in np.arange( imgs.shape[3] ):
                    tar_roi = np.squeeze( imgs[cyc, y0:y1, x0:x1, ch]).astype(np.float32)
                    tar_aligned = fourier_shift( np.fft.fftn(tar_roi), shift)
                    tar_aligned = np.fft.ifftn( tar_aligned )
                    tar_aligned = np.round( tar_aligned).astype( np.uint16 )
                    
                    imgs_aligned[ cyc, y:grid_y[j+1], x:grid_x[i+1], ch] = tar_aligned[y00:y11, x00:x11]
    return imgs_aligned   
      

# create big swath files 
dimension_order = 'TYXC'    #loaded
dimension_order = 'TCYX'    #imageJ-fiji
cnt, cnt_max = 0, 1000

rootf = f"{TMP_Path}swath_sss_cyc_0ttt_z_zzz_chcorrected.tiff"
try:
    for swath in Swaths:
        # for tile in tiles:
        imgs = np.zeros((len(Cycles), Tile_SZ[0]*len(Tiles), Tile_SZ[1], 3), dtype=np.uint16)
        for i,cyc in enumerate(Cycles):        
            f = rootf.replace('zzz', f'{Z}')
            if cyc<10:
                f = f.replace('ttt', f'0{cyc}')
            else:
                f = f.replace('ttt', f'{cyc}')
                
            f = f.replace('sss', f'{swath}')
            img = skimage.io.imread(f) 
            imgs[i, :,:,:] = img            # notice cyc is from 1   
                               
        aligned = align(imgs)
        img_fiji= aligned.transpose(0,3,1,2)
        of = f'{TMP_Path}aligned_swath_{swath}_z_{Z}.npz'
        np.savez(file=of, img=img_fiji)
        # tifffile.imwrite(of, img_fiji.astype('uint16'), 
        #                 shape=img_fiji.shape, imagej=True, 
        #                 metadata={'axes':dimension_order,})
        cnt += 1        
        if cnt>cnt_max:
            raise StopIteration
except StopIteration:
    pass

  tar_aligned = np.round( tar_aligned).astype( np.uint16 )


## XTALK CORRECTION

In [4]:
# cross talk?
import numpy as np 
xtalk = [[1, .609, .005], 
         [.072, 1, .009],
         [.02,.05, 1]]
xtalk_inv = np.linalg.inv(xtalk)

def xtalk_c(img):
    # shape 'YXC'
    img = np.array( img ).astype(np.float32)
    x_inv = np.array( [[ 1.04585019e+00, -6.36947927e-01, 5.03280415e-04],
                       [-7.51467763e-02,  1.04621640e+00, -9.04021370e-03],
                       [-1.71596649e-02, -3.95718613e-02, 1.00044195e+00]] )
    h, w = img.shape[0], img.shape[1]
    img = np.reshape( img, (3, h*w))
    img_inv = np.einsum("ij,jk->ik", x_inv, img)
    img_inv[ img_inv<=0 ] = 0
    img_inv = np.round( img_inv ).astype(np.uint16)
    img_inv = np.reshape( img_inv,(h,w,3))
    
    return img_inv

rootf = f"{TMP_Path}aligned_swath_sss_z_zzz.npz"

if XTALK_CORRECTION:
    for swath in Swaths:
        f = rootf.replace('sss', f'{swath}')
        f = f.replace('zzz', f'{Z}')
        data = np.load(f) 
        data = data['img']
        data_corr = np.zeros( (data.shape[0], data.shape[1], data.shape[2], data.shape[3]), dtype=np.uint16 )
        for i in np.arange( data.shape[0] ):      # 'TCYX'
            tmp = data[i,:,:,:]
            tmp = tmp.transpose(1,2,0)            # 'YXC'
            tmp = xtalk_c(tmp)
            data_corr[i,:,:,:] = tmp.transpose(2,0,1)
        of = f"{TMP_Path}aligned_xtalk_corr_swath_{swath}_z_{Z}.npz"
        np.savez( file=of, img=data_corr )

## STITCHING

In [6]:
import scipy.ndimage, os,skimage, tifffile 
from skimage.registration import phase_cross_correlation 
from scipy.ndimage import fourier_shift 
import numpy as np

def cal_shift(ref, tar):
    ref = skimage.exposure.equalize_adapthist( ref, clip_limit=0.05)
    tar = skimage.exposure.equalize_adapthist( tar, clip_limit=0.05)
    shift, error, diffphase = phase_cross_correlation(ref, tar)
    # tar_aligned = fourier_shift( np.fft.fftn(tar), shift)
    # tar_aligned = np.fft.ifftn( tar_aligned )
    # tar_aligned = np.round( tar_aligned).astype( np.uint16 )
    # print(shift)
    return shift

overlap = 570
if XTALK_CORRECTION:
    rootf = f"{TMP_Path}aligned_xtalk_corr_swath_sss_z_zzz.npz"
else:
    rootf = f"{TMP_Path}aligned_swath_sss_z_zzz.npz"
    
imgs  = []
for swath in Swaths:
    f = rootf.replace('sss', f'{swath}')
    f = f.replace('zzz', f'{Z}')
    data = np.load(f) 
    data = data['img']
    img = np.squeeze( data[Cyc_0, Ch_0, :,:] )
    imgs.append( img )
    
    
swath1  = imgs[0] 
h, w = swath1.shape[0], swath1.shape[1] 
offset = [] 
ref = swath1[:, -overlap:]
for img in imgs[1:]:
    tar = img[:, :overlap]
    shift = cal_shift(ref, tar)
    offset.append( shift )
    ref = img[:,-overlap:]
    
offset = np.round( offset ).astype( np.int16 )
print( f"offset = \n{offset}" )

offset = 
[[ 27 165]
 [-33 166]
 [ 31 163]]


In [7]:
#create a cyc_, ch_ tiff
data = np.zeros((len(Cycles), len(Channels), h, w*len(Swaths)), dtype=np.uint16 )                #tcyx
#swath 1
f = rootf.replace( 'sss', '1')
f = f.replace('zzz', f'{Z}')
img = np.load(f) 
swath1 = img['img']                                                #TCYX
data[:,:,:h, :w] = swath1
# print( swath1.shape, data.shape )

r_overlap = int( overlap/2 )
for k,swath in enumerate(Swaths[1:]):
    
    f = rootf.replace('sss', f'{swath}')
    f = f.replace('zzz', f'{Z}')
    im = np.load(f) 
    im = im['img']                                                  #tcyx
    
    shiftx = np.cumsum( offset[:k+1, 1])[-1] - overlap * (k+1)   #
    shifty = np.cumsum( offset[:k+1, 0])[-1]                     #
    
    l = shiftx + (k+1)*w + r_overlap
    r = shiftx + (k+2)*w
    
    for i,cyc in enumerate(Cycles):
        for j,ch in enumerate(Channels):
            tmp = np.squeeze( im[i,j,:,:] )
            tmp = np.roll(a=tmp,
                  shift = shifty,
                  axis=0)
            data[i,j, :, l:r] = tmp[:, r_overlap:]

In [8]:
for i,cyc in enumerate(Cycles):
    for j,ch in enumerate(Channels):
        f = f"{TMP_Path}c_{cyc}_ch_{ch}.tiff"
        img_fiji = data[i, j, :, :].astype( np.uint16 )
        dimension_order = 'YX'
        tifffile.imwrite(f, img_fiji.astype('uint16'), 
                         shape=img_fiji.shape, imagej=True, 
                         metadata={'axes':dimension_order,})  