In [1]:
import numpy as np
from cellpose import models, core
from pathlib import Path
import os,glob
from cellpose import io
import timeit
from omnipose import core as oc
import ncolor.label as nl
import ncolor
import read_roi as rr
import tifffile
import pickle as pk
import momia2 as mo

## Timelapse video XY drift correction

In [1]:
# helper functions for XY drift correction of time-lapse stacks
def xydrift_correction(target_img, shift, max_drift=1000):
    """
    Apply an XY shift to an image unless the shift is implausibly large.

    :param target_img: image to shift
    :param shift: per-axis shift (e.g. as returned by get_xydrift)
    :param max_drift: largest per-axis shift magnitude that is still applied
    :return: the shifted image produced by shift_image, or target_img itself
             (unchanged, same object) when the shift is rejected
    """
    within_limit = max(np.abs(shift)) <= max_drift
    if not within_limit:
        # reject: hand back the original image untouched
        return target_img
    return shift_image(target_img, np.array(shift))
    
def get_xydrift(ref_img, target_img):
    """
    Estimate the sub-pixel shift that registers target_img onto ref_img.

    Uses scikit-image phase cross-correlation with upsample_factor=10,
    i.e. registration to within 1/10 of a pixel.

    :param ref_img: reference image
    :param target_img: image whose offset relative to ref_img is measured
    :return: shift vector, one entry per image axis
    """
    from skimage.registration import phase_cross_correlation
    # error and phase difference are computed by skimage but not needed here
    drift, _error, _phasediff = phase_cross_correlation(
        ref_img, target_img, upsample_factor=10
    )
    return drift

def shift_image(img, shift):
    """
    Shift an image by a (sub-pixel) offset using a Fourier-domain shift.

    Used to correct xy drift, e.g. between frames of a time-lapse or between a
    phase contrast image and fluorescent image(s). The Fourier shift is
    circular: content pushed past one edge reappears on the opposite edge, so
    a border roughly the size of the shift holds wrapped-around content.

    :param img: input image
    :param shift: subpixel shift, one value per image axis
    :return: drift corrected image as uint16, values clipped to [1, 65530]
    """
    # import after the docstring so the string above is the function's __doc__
    from scipy import ndimage as ndi
    shifted = np.fft.ifftn(ndi.fourier_shift(np.fft.fftn(img), shift))
    # back to integer intensities (real part only)
    shifted = np.round(shifted.real)
    # clip into the uint16 range before casting so negative ringing or
    # overshoot cannot wrap around; the floor is 1 (not 0) as before
    return np.clip(shifted, 1, 65530).astype(np.uint16)


def correct_tl(pos,softread_data,n_channels=2,ref_channel=1):
    """
    Drift-correct one position of a multi-channel time-lapse and crop the border.

    Frame-to-frame XY drift is measured on ``ref_channel`` with get_xydrift,
    accumulated relative to frame 0, and applied to every channel with
    xydrift_correction. Because the Fourier shift used there is circular,
    a margin equal to the largest accumulated drift (rounded up) is cropped
    from all four sides.

    :param pos: position index passed to mo.io.get_slice_by_index
    :param softread_data: softread data (as from mo.io.softread_file); must
                          provide 'n_timepoints'
    :param n_channels: number of channels to read (channels 0..n_channels-1)
    :param ref_channel: channel used to measure drift
    :return: array indexed (time, channel, y, x), cropped by the drift margin
    """
    n_timepoints = softread_data['n_timepoints']
    # channel_stacks[c][t] is the raw frame of channel c at time t
    # (assumes get_slice_by_index returns a 2D (y, x) frame -- TODO confirm)
    channel_stacks = [
        np.array([mo.io.get_slice_by_index(softread_data, channel=c, position=pos, time=t)
                  for t in range(n_timepoints)])
        for c in range(n_channels)
    ]

    cumulative_shift = np.zeros(2)
    max_drift = 0.0
    output_data = []
    for t in range(n_timepoints):
        if t == 0:
            # frame 0 is the registration reference and is kept unshifted
            t_frame = [stack[0] for stack in channel_stacks]
        else:
            # accumulate frame-to-frame shifts so each frame is registered to frame 0
            cumulative_shift = cumulative_shift + get_xydrift(
                channel_stacks[ref_channel][t - 1], channel_stacks[ref_channel][t])
            # largest magnitude on either axis: abs() must come before max(),
            # otherwise a negative drift (e.g. [-5, 2] -> 2) is underestimated
            max_drift = max(max_drift, float(np.max(np.abs(cumulative_shift))))
            t_frame = [xydrift_correction(stack[t], cumulative_shift) for stack in channel_stacks]
        output_data.append(np.array(t_frame))
    output_data = np.array(output_data)

    # round up: a fractional shift still contaminates the partially wrapped pixel
    margin = int(np.ceil(max_drift))
    height, width = output_data.shape[-2:]
    # explicit end indices so a zero margin keeps the full frame
    # (x[0:-0] would be an empty slice)
    return output_data[:, :, margin:height - margin, margin:width - margin]

def correct_tl_simp(pos,timearray,n_channels=1,ref_channel=0):
    """
    Drift-correct a time-lapse held in memory (or in softread data) and crop the border.

    Frame-to-frame XY drift is measured on ``ref_channel`` with get_xydrift,
    accumulated relative to frame 0, and applied to every channel with
    xydrift_correction. A margin equal to the largest accumulated drift
    (rounded up) is cropped from all four sides to remove wrapped content.

    :param pos: position index; only used when ``timearray`` is softread data
    :param timearray: either a numpy array shaped (T, Y, X) for a single
                      channel or (T, C, Y, X), or softread data (as from
                      mo.io.softread_file) providing 'n_timepoints'
    :param n_channels: number of channels to process (channels 0..n_channels-1)
    :param ref_channel: channel used to measure drift
    :return: array indexed (time, channel, y, x), cropped by the drift margin
    """
    if isinstance(timearray, np.ndarray):
        # a (T, Y, X) series is single-channel: add a channel axis -> (T, C, Y, X)
        stack = timearray[:, np.newaxis] if timearray.ndim == 3 else timearray
        n_timepoints = stack.shape[0]
        channel_series = [stack[:, c] for c in range(n_channels)]
    else:
        n_timepoints = timearray['n_timepoints']
        channel_series = [
            np.array([mo.io.get_slice_by_index(timearray, channel=c, position=pos, time=t)
                      for t in range(n_timepoints)])
            for c in range(n_channels)
        ]

    cumulative_shift = np.zeros(2)
    max_drift = 0.0
    output_data = []
    for t in range(n_timepoints):
        if t == 0:
            # frame 0 is the registration reference and is kept unshifted
            t_frame = [series[0] for series in channel_series]
        else:
            # accumulate frame-to-frame shifts so each frame is registered to frame 0
            cumulative_shift = cumulative_shift + get_xydrift(
                channel_series[ref_channel][t - 1], channel_series[ref_channel][t])
            # abs() before max() so negative drift is not underestimated
            max_drift = max(max_drift, float(np.max(np.abs(cumulative_shift))))
            t_frame = [xydrift_correction(series[t], cumulative_shift) for series in channel_series]
        output_data.append(np.array(t_frame))
    output_data = np.array(output_data)

    # round up: a fractional shift still contaminates the partially wrapped pixel
    margin = int(np.ceil(max_drift))
    height, width = output_data.shape[-2:]
    # explicit end indices so a zero margin keeps the full frame
    return output_data[:, :, margin:height - margin, margin:width - margin]

In [None]:
# batch drift correction: register every field (position 0 of each .nd2 file)
# and write the result as an ImageJ TCYX hyperstack.
# Files whose path contains '20.nd2' are skipped (field #20 has suboptimal image quality).
for nd2_path in sorted(glob.glob('/Volumes/JZSSD_temp/20190408_NQTF/*.nd2')):
    if '20.nd2' in nd2_path:
        continue
    soft = mo.io.softread_file(nd2_path)
    # file name up to its first '.', used to name the output
    header = nd2_path.rsplit('/', 1)[-1].split('.')[0]
    drift_corrected = correct_tl(0, soft, n_channels=2, ref_channel=1)
    tifffile.imwrite(
        '/Volumes/JZSSD_temp/20190408_NQTF/driftcorrected_{}.tif'.format(header),
        drift_corrected,
        imagej=True,
        metadata={'axes': 'TCYX'},
    )

## Manually define ROIs using Fiji

## Image segmentation using Omnipose
Omnipose documentation: https://omnipose.readthedocs.io

Omnipose publication: PMID 36253643

In [2]:
# display the pretrained model names known to this cellpose/omnipose install
models.MODEL_NAMES

['cyto',
 'nuclei',
 'cyto2',
 'bact_phase_cp',
 'bact_fluor_cp',
 'plant_cp',
 'worm_cp',
 'cyto2_omni',
 'bact_phase_omni',
 'bact_fluor_omni',
 'plant_omni',
 'worm_omni',
 'worm_bact_omni',
 'worm_high_res_omni']

In [3]:
# pretrained Omnipose model for bacterial phase-contrast images;
# GPU disabled, so inference runs on the CPU
model_name = 'bact_phase_omni'
model = models.CellposeModel(model_type=model_name, gpu=False)

2023-01-05 11:01:24,616 [INFO] >>bact_phase_omni<< model set to be used
2023-01-05 11:01:24,618 [INFO] >>>> using CPU


In [59]:
# batch segmentation
verbose = 0 # turn on if you want to see more output 
transparency = True # transparency in flow output
rescale=0.6 # give this a number if you need to upscale or downscale your images
omni = True # we can turn off Omnipose mask reconstruction, not advised 
flow_threshold = 0.3 # default is .4; passed to model.eval below (0.3 is the value the saved masks were produced with)
resample = True #whether or not to run dynamics on rescaled grid or original grid 
chans = [0,0]

for f in sorted(glob.glob('/Volumes/JZSSD_temp/20221120_NQTF/pred_results/*_clips.npy')):
    header = f.split('_clips')[0]
    omni_mask_fname = header+'_omni_masks.npy'
    omni_flow_fname = header+'_omni_flows.npy'
    # skip clips whose masks already exist, so an interrupted run can be resumed
    if os.path.isfile(omni_mask_fname):
        continue
    images = np.load(f)
    masks, flows, styles = model.eval(list(images),
                                      channels=chans,
                                      rescale=rescale,
                                      transparency=transparency,
                                      flow_threshold=flow_threshold,
                                      omni=omni,
                                      resample=resample,
                                      verbose=verbose)
    masks = np.array(masks)
    # the per-image flow outputs are ragged (components of different shapes);
    # dtype=object reproduces what older NumPy built implicitly (with a
    # deprecation warning) and is required on NumPy >= 1.24
    flows = np.array(flows, dtype=object)
    np.save(omni_mask_fname,masks)
    np.save(omni_flow_fname,flows)

2023-01-05 16:35:19,032 [INFO] 0%|          | 0/90 [00:00<?, ?it/s]
2023-01-05 16:35:21,134 [INFO] 1%|1         | 1/90 [00:02<03:06,  2.10s/it]
2023-01-05 16:35:22,956 [INFO] 2%|2         | 2/90 [00:03<02:50,  1.94s/it]
2023-01-05 16:35:24,624 [INFO] 3%|3         | 3/90 [00:05<02:37,  1.81s/it]
2023-01-05 16:35:26,240 [INFO] 4%|4         | 4/90 [00:07<02:29,  1.74s/it]
2023-01-05 16:35:27,753 [INFO] 6%|5         | 5/90 [00:08<02:20,  1.66s/it]
2023-01-05 16:35:29,243 [INFO] 7%|6         | 6/90 [00:10<02:14,  1.60s/it]
2023-01-05 16:35:30,709 [INFO] 8%|7         | 7/90 [00:11<02:09,  1.56s/it]
2023-01-05 16:35:32,170 [INFO] 9%|8         | 8/90 [00:13<02:05,  1.53s/it]
2023-01-05 16:35:33,653 [INFO] 10%|#         | 9/90 [00:14<02:02,  1.51s/it]
2023-01-05 16:35:35,113 [INFO] 11%|#1        | 10/90 [00:16<01:59,  1.50s/it]
2023-01-05 16:35:36,588 [INFO] 12%|#2        | 11/90 [00:17<01:57,  1.49s/it]
2023-01-05 16:35:38,070 [INFO] 13%|#3        | 12/90 [00:19<01:56,  1.49s/it]
2023-01-05 16

  flows = np.array(flows)


2023-01-05 16:37:42,555 [INFO] 0%|          | 0/90 [00:00<?, ?it/s]
2023-01-05 16:37:43,917 [INFO] 1%|1         | 1/90 [00:01<02:01,  1.36s/it]
2023-01-05 16:37:44,848 [INFO] 2%|2         | 2/90 [00:02<01:37,  1.11s/it]
2023-01-05 16:37:45,768 [INFO] 3%|3         | 3/90 [00:03<01:28,  1.02s/it]
2023-01-05 16:37:46,703 [INFO] 4%|4         | 4/90 [00:04<01:24,  1.01it/s]
2023-01-05 16:37:47,600 [INFO] 6%|5         | 5/90 [00:05<01:21,  1.05it/s]
2023-01-05 16:37:48,500 [INFO] 7%|6         | 6/90 [00:05<01:18,  1.07it/s]
2023-01-05 16:37:49,390 [INFO] 8%|7         | 7/90 [00:06<01:16,  1.09it/s]
2023-01-05 16:37:50,290 [INFO] 9%|8         | 8/90 [00:07<01:14,  1.09it/s]
2023-01-05 16:37:51,205 [INFO] 10%|#         | 9/90 [00:08<01:14,  1.09it/s]
2023-01-05 16:37:52,117 [INFO] 11%|#1        | 10/90 [00:09<01:13,  1.09it/s]
2023-01-05 16:37:53,054 [INFO] 12%|#2        | 11/90 [00:10<01:12,  1.09it/s]
2023-01-05 16:37:53,986 [INFO] 13%|#3        | 12/90 [00:11<01:12,  1.08it/s]
2023-01-05 16

KeyboardInterrupt: 