In [None]:
import random
import glob
import json
from pathlib import Path
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
import h5py as h5
from skimage.exposure import rescale_intensity
from skimage.transform import rescale
from scipy.ndimage import gaussian_filter
import torch
from cellpose.models import Cellpose
from cellpose import dynamics

from scipy.ndimage.filters import maximum_filter1d
import fastremap

In [None]:
# set to True to replace cellpose's `dynamics.get_masks` with the version below
monkey_patch_large_cell = False

def get_masks(p, iscell=None, rpad=20):
    """ create masks using pixel convergence after running dynamics

    Variant of cellpose's `dynamics.get_masks` in which the "remove big masks"
    step (masks larger than 40% of the image) is disabled, so that very large
    objects are kept.

    Makes a histogram of final pixel locations p, initializes masks
    at peaks of histogram and extends the masks from the peaks so that
    they include all pixels with more than 2 final pixels p.

    Parameters
    ----------------
    p: float32, 3D or 4D array
        final locations of each pixel after dynamics,
        size [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
        NOTE: modified in place where `iscell` is False.
    iscell: bool, 2D or 3D array
        if iscell is not None, set pixels that are
        iscell False to stay in their original location.
    rpad: int (optional, default 20)
        histogram edge padding

    Returns
    ---------------
    M0: integer, 2D or 3D array
        0=NO masks; 1,2,...=mask labels (renumbered without gaps),
        size [Ly x Lx] or [Lz x Ly x Lx]

    """
    # local import: the `scipy.ndimage.filters` namespace imported at the top of
    # the notebook is deprecated; `scipy.ndimage` exposes the same function
    from scipy.ndimage import maximum_filter1d

    pflows = []
    edges = []
    shape0 = p.shape[1:]
    dims = len(p)
    if iscell is not None:
        # pixels outside cells are reset to their own grid location
        inds = np.meshgrid(*[np.arange(n) for n in shape0], indexing='ij')
        for i in range(dims):
            p[i, ~iscell] = inds[i][~iscell]

    # integer final positions; histogram bin edges are padded by `rpad` per side
    for i in range(dims):
        pflows.append(p[i].flatten().astype('int32'))
        edges.append(np.arange(-.5-rpad, shape0[i]+.5+rpad, 1))

    h, _ = np.histogramdd(tuple(pflows), bins=edges)
    hmax = h.copy()
    for i in range(dims):
        hmax = maximum_filter1d(hmax, 5, axis=i)

    # seeds: bins that are maxima of a 5-bin window along every axis and that
    # received more than 10 pixels
    seeds = np.nonzero(np.logical_and(h-hmax>-1e-6, h>10))
    Nmax = h[seeds]
    isort = np.argsort(Nmax)[::-1]
    # order seeds by decreasing count (the former `for s in seeds: s = s[isort]`
    # only rebound the loop variable and left the seeds unsorted)
    seeds = tuple(s[isort] for s in seeds)

    # one entry per seed: list of per-axis coordinate arrays of its current bins
    pix = [[np.atleast_1d(c) for c in coord] for coord in zip(*seeds)]

    shape = h.shape
    # per-axis offsets (0, 1, 2) of the full 3x3 / 3x3x3 neighbourhood
    expand = np.nonzero(np.ones((3,) * dims))

    # grow each seed 5 times by one bin into neighbours holding more than 2 pixels
    for _ in range(5):
        for k in range(len(pix)):
            newpix = []
            iin = []
            for i, e in enumerate(expand):
                epix = e[:, np.newaxis] + np.expand_dims(pix[k][i], 0) - 1
                epix = epix.flatten()
                iin.append(np.logical_and(epix>=0, epix<shape[i]))
                newpix.append(epix)
            iin = np.all(tuple(iin), axis=0)
            # drop out-of-range neighbours (the former `for p in newpix: p = p[iin]`
            # was a no-op, so negative indices silently wrapped around)
            newpix = tuple(c[iin] for c in newpix)
            igood = h[newpix]>2
            grown = np.stack([c[igood] for c in newpix])
            # de-duplicate coordinates so the lists do not grow geometrically with
            # every iteration; the set of labelled bins is unchanged
            if grown.shape[1] > 0:
                grown = np.unique(grown, axis=1)
            pix[k] = list(grown)

    M = np.zeros(h.shape, np.uint32)
    for k in range(len(pix)):
        M[tuple(pix[k])] = 1+k

    # look up the label of the bin each pixel converged to
    for i in range(dims):
        pflows[i] = pflows[i] + rpad
    M0 = M[tuple(pflows)]

    # NB: commented out to remove hardcoded 40% image size max cell size filter
    # remove big masks
#     uniq, counts = fastremap.unique(M0, return_counts=True)
#     big = np.prod(shape0) * 0.4
#     bigc = uniq[counts > big]
#     if len(bigc) > 0 and (len(bigc)>1 or bigc[0]!=0):
#         M0 = fastremap.mask(M0, bigc)

    fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels
    M0 = np.reshape(M0, shape0)
    return M0

# monkey-patch if desired
if monkey_patch_large_cell:
    dynamics.get_masks = get_masks

In [None]:
def load_all_details(in_file, normalize=True, percentiles=(2.5, 99.5)):
    """Read every 'detail' image stored in an HDF5 experiment file.

    Parameters
    ----------
    in_file : str or path-like
        HDF5 file, opened read-only. Every key under 'experiment' whose name
        contains 'detail' is read from 'experiment/<key>/0/0'.
    normalize : bool
        if True, each image is rescaled to uint8 between its own percentiles.
    percentiles : tuple of float
        (low, high) percentiles, in percent, computed on the raw image.

    Returns
    -------
    imgs : list of np.ndarray
        one array per detail dataset (uint8 if `normalize`, raw otherwise).
    percentiles_ : list of tuple
        percentile values in RAW intensity units, one tuple per image.
    names : list of str
        detail dataset names, in the same order as `imgs`.
    """
    images, intensity_ranges, detail_names = [], [], []

    with h5.File(in_file, 'r') as fd:
        for name in [key for key in fd['experiment'].keys() if 'detail' in key]:
            pixels = fd['experiment/{}/0/0'.format(name)][...]

            # percentiles always come from the raw data, before any rescaling
            low_high = tuple(np.percentile(pixels, percentiles))

            if normalize:
                pixels = rescale_intensity(pixels, low_high, 'uint8').astype(np.uint8)

            images.append(pixels)
            intensity_ranges.append(low_high)
            detail_names.append(name)

    return images, intensity_ranges, detail_names

In [None]:
# HDF5 files to process, in a deterministic (sorted) order
in_files = sorted(glob.glob('/scratch/hoerl/auto_sir_dna_comp/20207*/*/*.h5'))

in_files

In [None]:
# NOTE: `load_all_details` is defined in an earlier cell. A second, identical
# definition used to live here; it was removed because a redefinition silently
# shadows the first and makes it unclear which version is in use.

In [None]:
from concurrent.futures import ThreadPoolExecutor

# percentiles for normalization
percentiles=(2.5, 99.8)
# whether to normalize per replicate (True) or per image (False)
normalize_per_replicate = True

# single-threaded version:
# loaded = {f: load_all_details(f) for f in in_files}

# the loader only normalizes each image itself when we do NOT normalize per replicate
normalize_each_image = not normalize_per_replicate

loaded = {}
with ThreadPoolExecutor() as tpe:
    pending = [(f, tpe.submit(load_all_details, f, normalize_each_image, percentiles)) for f in in_files]
    n_total = len(pending)
    # collect results in submission order so `loaded` keeps the order of `in_files`
    for i, (f, future) in enumerate(pending, start=1):
        loaded[f] = future.result()
        print(f'({i}/{n_total}): {f}')

In [None]:

# take the first loaded file (replicate) and smooth one of its images for a quick test
if not loaded:
    raise RuntimeError('no input files were loaded - check the glob pattern used for in_files')
imgs, percs, names = next(iter(loaded.values()))
if not imgs:
    raise RuntimeError('the first input file contains no detail images')

# third image if available, otherwise the last one
idx = min(2, len(imgs)-1)
img = gaussian_filter(imgs[idx].squeeze().astype(np.float32), 5)

In [None]:
# use the first GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = Cellpose(model_type='nuclei', net_avg=True, device=device)
masks, flows, styles, diams = model.eval([img], channels=[0,0], rescale=False, diameter=500, normalize=True, flow_threshold=0.3)

In [None]:
fig, axs = plt.subplots(ncols=2, figsize=(10,5))
axs[0].imshow(img.squeeze(), cmap='gray')
axs[0].set_title('input (smoothed)')
# label image: nearest-neighbour interpolation avoids blending neighbouring label ids
axs[1].imshow(masks[0].squeeze(), interpolation='nearest')
axs[1].set_title('Cellpose masks');

In [None]:
# evaluate on a reproducible random subset of (at most 10) images of this replicate
n_imgs_to_test = min(len(imgs), 10)
sample_rng = random.Random(0)  # fixed seed so the same subset is drawn on every run
test_imgs = sample_rng.sample(imgs, n_imgs_to_test)
test_imgs = [gaussian_filter(im.squeeze().astype(np.float32), 5) for im in test_imgs]

masks, _, _, _ = model.eval(test_imgs, channels=[0,0], rescale=False, diameter=1000, normalize=True, flow_threshold=0.4, cellprob_threshold=0.)

In [None]:
# show each smoothed test image next to its Cellpose label image
for img, mask in zip(test_imgs, masks):
    fig, axs = plt.subplots(ncols=2, figsize=(10, 5))
    img_ax, mask_ax = axs
    img_ax.imshow(img.squeeze(), cmap='gray')
    mask_ax.imshow(mask.squeeze(), interpolation='nearest')
    

## Use Cellpose to segment large images

For example, spinning-disk data or stitched overviews.

### get list of files to process

In [None]:
import glob
from nd2reader import ND2Reader
from skimage.io import imread


# new DAPI-stained (ctrl) IMR90 (Dec 2022)
# files = glob.glob('/data/cooperation_data/ArgyrisPapantonis-nuclear_architecture/Hartmann_Harz/IMR90_30112022/*.nd2')

# stitched overviews from autoSTED (sorted for a deterministic order)
files = sorted(glob.glob('/scratch/hoerl/20230507_imr90_stitching/20230507_imr90_ov_stitch_output/*.tif'))

files

In [None]:
# Load one image per file. ND2 files are read with ND2Reader; any other format
# (e.g. the stitched .tif overviews selected above) falls back to imread, so that
# running all cells in order does not crash on non-ND2 inputs.
# TODO: the ND2 branch only loads single plane images at the moment

imgs = []
for file in files:
    if file.lower().endswith('.nd2'):
        with ND2Reader(file) as reader:
            img = np.array(reader[0])
    else:
        img = imread(file)
    print(f'loaded {file}')
    imgs.append(img)

In [None]:
# generic loader: one array per file, in the order of `files`
imgs = []
for file in files:
    data = imread(file)
    imgs.append(data)
    print(f'loaded {file}')

In [None]:
# max project in z (axis 0). Arrays that are already 2D are left untouched, so
# single-plane inputs are not collapsed to 1D and re-running this cell is harmless.
imgs = [img.max(axis=0) if img.ndim > 2 else img for img in imgs]

### CLAHE to increase contrast before segmentation

In [None]:
from skimage.exposure import equalize_adapthist
from matplotlib import pyplot as plt

# preview CLAHE on one image (index clamped so it also works with fewer files)
idx = min(12, len(imgs) - 1)
img = imgs[idx]

# NOTE(review): these parameters differ from the batch CLAHE cell below
# (kernel_size=500, clip_limit=0.02) - confirm which setting is intended
img_clahe = equalize_adapthist(img, kernel_size=100, clip_limit=0.005)

fig, axs = plt.subplots(ncols=2, figsize=(12, 6))
axs[0].imshow(img, cmap='gray')
axs[0].set_title('original')
axs[1].imshow(img_clahe, cmap='gray')
axs[1].set_title('CLAHE');


In [None]:
# NOTE(review): the segmentation cells below read `imgs`, not `imgs_clahe`
imgs_clahe = []

for idx, raw in enumerate(imgs):
    # contrast-limited adaptive histogram equalization
    img_clahe = equalize_adapthist(raw, kernel_size=500, clip_limit=0.02)

    # keep black (zero-valued) borders at exactly 0
    # TODO: really necessary?
    img_clahe[raw == 0] = 0

    imgs_clahe.append(img_clahe)
    print(f'CLAHE on {files[idx]} done.')

### segment with Cellpose

Verdict: works reasonably well, and CLAHE sometimes helps, but some cells are missed, so the Cellpose parameters may need further refinement.

In [None]:
diam = 70
flow_thresh = 0.7

# preview image (index clamped so the cell also works with fewer files)
idx = min(35, len(imgs) - 1)
img = imgs[idx]

# prefer the second GPU as before; fall back to the first GPU, then to the CPU
if torch.cuda.device_count() > 1:
    dev = torch.device('cuda:1')
elif torch.cuda.is_available():
    dev = torch.device('cuda:0')
else:
    dev = torch.device('cpu')
model = Cellpose(model_type='nuclei', net_avg=True, device=dev)

(mask,), _, _, _ = model.eval([img], channels=[0,0], diameter=diam, flow_threshold=flow_thresh)

fig, axs = plt.subplots(ncols=2, figsize=(12, 6))
axs[0].imshow(img, cmap='gray')
axs[1].imshow(mask, cmap='prism', interpolation='nearest');

In [None]:
diam = 70
flow_thresh = 0.7

masks = []

# one Cellpose call per image (TODO: maybe do in one call?)
for img, file in zip(imgs, files):
    outputs = model.eval([img], channels=[0, 0], diameter=diam, flow_threshold=flow_thresh)
    (mask,), flows, styles, diams = outputs
    masks.append(mask)
    print(f'segmented {file}')

### Alternative: segment with just thresholding and a bit of morphological ops

Verdict: might need more post-processing, especially when CLAHE is applied beforehand.

In [None]:
from skimage.filters import threshold_li
from skimage.measure import label, regionprops
from skimage.morphology import binary_closing, remove_small_holes, remove_small_objects, disk


def clear_large_objects(labels, max_size, bgval=0, in_place=False):
    """Set every labelled object with more than `max_size` pixels to `bgval`.

    Parameters
    ----------
    labels : integer array
        label image; values <= 0 are treated as background and never cleared.
    max_size : int or float
        objects with strictly more than `max_size` pixels are removed.
    bgval : int
        value written into the pixels of removed objects.
    in_place : bool
        if True, modify `labels` directly; otherwise work on a copy.

    Returns
    -------
    The label image with the large objects replaced by `bgval`.
    """
    if not in_place:
        labels = labels.copy()

    # pixel count of every label, computed once before any pixel is overwritten,
    # so the result does not depend on processing order (e.g. if bgval is itself
    # an existing label)
    ids, counts = np.unique(labels, return_counts=True)
    too_large = ids[(ids > 0) & (counts > max_size)]

    if too_large.size:
        labels[np.isin(labels, too_large)] = bgval

    return labels


hole_and_small_object_size = 500
disk_radius = 3
max_object_size = 500 * 500


def _segment_by_threshold(image):
    """Li threshold plus morphological cleanup; returns a label image."""
    # NOTE: we calculate the threshold from pixels > 0 to ignore borders created by stitching
    foreground = image > threshold_li(image[image > 0])

    # some morphological cleanup
    foreground = remove_small_holes(foreground, hole_and_small_object_size)
    foreground = remove_small_objects(foreground, hole_and_small_object_size)
    foreground = binary_closing(foreground, disk(disk_radius))

    return clear_large_objects(label(foreground), max_object_size)


masks = []

for img, file in zip(imgs, files):
    masks.append(_segment_by_threshold(img))
    print(f'segmented {file}')

### quick visualization of one image (cropped region)

In [None]:
from matplotlib import pyplot as plt

# image to inspect (index clamped so the cell also works with fewer files)
idx = min(4, len(imgs) - 1)
cut = (slice(0,3000), slice(0,3000))
# clim = (0, 1200)

fig, axs = plt.subplots(ncols=2, figsize=(12,6))

axs[0].imshow(imgs[idx][cut], cmap='gray')
axs[1].imshow(masks[idx][cut], cmap='rainbow', interpolation='nearest')

fig.tight_layout()

### save masks / labels

In [None]:
from pathlib import Path
import os
from tifffile import imwrite

# folder name if using cellpose
folder_name = f'segmentation_cellpose_maxproj_d{diam}_ft{str(flow_thresh).replace(".", "")}'
# if using just thresholding
# folder_name = 'segmentation_threshold'
# NOTE(review): `masks` holds the results of whichever segmentation cell ran last;
# after running all cells in order these are the thresholding results - pick
# `folder_name` accordingly.

# one label image per input file is expected; zip() below would silently drop extras
if len(masks) != len(files):
    raise ValueError(f'got {len(masks)} label images for {len(files)} input files')

# make folder if necessary (exist_ok avoids a separate check-then-create step)
folder = Path(files[0]).parent / folder_name
folder.mkdir(exist_ok=True)

# save masks
for mask, file in zip(masks, files):
    outfile = folder / (Path(file).name + '_labels.tif')
    imwrite(outfile, mask)

In [None]:
# check that we have results in the right place

!ls -l /data/cooperation_data/ArgyrisPapantonis-nuclear_architecture/Hartmann_Harz/IMR90_30112022/$folder_name/