# Extracting nuclei from 3D image stacks

This notebook allows user to load 3D microscopy images, segment nuclei, extract windows around nuclei, perform manual QC, and save them in a format suitable for downstream ML applications.

Output files will feature three tags that serve different purposes in forming A-P-N triplets:

sampleID: Negative images will have different sampleIDs (example: Zelda, Rpb1-nc12)
stackID: Positive images will have the same stackID (example: 20211115-zld-gfp-em1-03)
nucID: Positive images will have different nucIDs (for image augmentation purposes; allows multiple images from a single nucleus)

Example filename:

zld_2021115-zld-gfp-03_nuc1_01.pkl

sampleID: zld
stack_ID: 2021115-zld-gfp-03
nucID: nuc1

In [1]:
import flymovie as fm
from flymovie.simnuc import Sim
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
from importlib import reload
import scipy.ndimage as ndi
import skimage as ski
from importlib import reload
%load_ext autoreload
%autoreload 2

In [5]:
#from skimage import regionprops
def extract_nuclei_stack(stack, lmask, box_dims, mask_dil_struct=np.ones((2,10,10)), mask_nucs=True, centroid_margin=15):
    """Extract boxes containing nuclei from labelmask."""
    def adjust_centroid(im_shape, halflens, centroid, centroid_margin):
        """Move centroid to keep bounding box within image."""
        centroid_new = [centroid[0]]
        for dim in [1,2]:
            left_bound = centroid[dim] - halflens[dim]
            right_bound = centroid[dim] + halflens[dim]
            if left_bound < 0:
                if abs(left_bound) <= centroid_margin:
                    print('oh fixed that for you')
                    centroid_new = centroid_new + [halflens[dim]]
                else:
                    return None
            elif right_bound > im_shape[dim]:
                if abs(right_bound - im_shape[dim]) <= centroid_margin:
                    print('fuuuuck fixed that for you')
                    centroid_new = centroid_new + [im_shape[dim] - halflens[dim]]
                else:
                    return None
            else:
                centroid_new = centroid_new + [centroid[dim]]

        return centroid_new
    
    def get_box(stack, box_halflengths, centroid):
        """Extract ij box, keeping all z."""
        box = stack[
                    :,
                    (centroid[1] - box_halflengths[1]):(centroid[1] + box_halflengths[1]),
                    (centroid[2] - box_halflengths[2]):(centroid[2] + box_halflengths[2])
                ]
        return box

    def midpoint_max(stack):
        midpoint_z = round(stack.shape[0] / 2)
        return np.max(stack[midpoint_z])

    regions = ski.measure.regionprops(lmask)
    box_halflengths = (
            int(box_dims[0] / 2),
            int(box_dims[1] / 2),
            int(box_dims[2] / 2), 
        )
    nuc_count = 0
    ims = []
    for region in regions:
        centroid = ([round(x) for x in region.centroid])
        centroid = adjust_centroid(lmask.shape, box_halflengths, centroid, centroid_margin)

        if centroid is not None:
            im_box = get_box(stack, box_halflengths, centroid)
            lmask_box = get_box(lmask, box_halflengths, centroid)
            mask_box = np.where(lmask_box == region.label, 1, 0)
            mask_box = ndi.morphology.binary_dilation(mask_box, structure=mask_dil_struct)
            if mask_nucs:
                im_box = np.where(mask_box, im_box, 0)
            # Remove a number of bad nuclei by ensuring center z slice is occupied.
            if midpoint_max(im_box) == 0:
                print("Han, I got one!")
                continue
            ims.append(im_box)
            nuc_count += 1
    
    return ims

######################
def extract_nuclei_batch(folder, sampleID, box_dims, datestr, mask_dil_struct, thresh_boost=700):
    nucs = []
    filestems = []
    for f in os.listdir(folder):
        if f[0] == '.':
            continue

        stackID = datestr + '-' + f.split('.')[0]
        stack = fm.read_czi(os.path.join(folder, f), swapaxes=False)
        if stack.ndim == 4:
            stack = stack[:-1]
        if stack.ndim == 3:
            stack = [stack]
        for i in range(len(stack)):
            lmask = fm.segment_nuclei_3Dstack_rpb1(stack[i], usemax=False, sigma=4, min_nuc_center_dist=50, thresh_boost=thresh_boost)
            lmask = fm.labelmask_filter_objsize(lmask, 10_000, 2.5e5)
            ims = extract_nuclei_stack(stack[i], lmask, box_dims, mask_dil_struct, mask_nucs=True, centroid_margin=12)
            nucID = 0
            for nuc in ims:
                nucs.append(nuc)
                filestem = sampleID + '_' + stackID + '-' + str(i) + '_' + str(nucID)
                filestems.append(filestem)
                nucID += 1
        
    return np.array(nucs), filestems

#################################
def save_nucs(nucs, filestems, bad_nucs, box_dims, outfolder, z_range=[0], mask_dil_struct=np.ones((1,7,7)), 
        maskfolder=None, mask_target_dims=(100,100,100)):
    """Save nuclei, possibly with a range of z start positions."""
    def get_center_z(box):
        z, i, j = np.where(box > 0)
        return round((np.max(z) + np.min(z)) / 2)

    def save_mask(im, maskfolder, filename, mask_dil_struct, target_dims):
        mask = np.where(im > 0, 1, 0)
        mask = ndi.morphology.binary_erosion(mask, structure=mask_dil_struct)
        mask = ndi.zoom(mask, np.divide(target_dims, mask.shape), order=0)
        
        filepath = os.path.join(maskfolder, filename)
        with open(filepath, 'wb') as outfile:
            pickle.dump(mask, outfile)

    saved_count = 0
    stems_saved = {}
    nucs_saved = []
    
    for i in range(nucs.shape[0]):
        if i not in bad_nucs:
            center_z = get_center_z(nucs[i])
            start_z = center_z - int(box_dims[0] / 2)
            filestem = filestems[i]

            for z_add in z_range:
                z = start_z + z_add
                filename = '_'.join([filestem, str(z)]) + '.pkl'
                filepath = os.path.join(outfolder, filename)
                im_to_save = nucs[i][z:(z + box_dims[0])]
                if np.array_equal(im_to_save.shape, box_dims):
                    saved_count += 1
                    stems_saved[filestem] = 1
                    nucs_saved.append(im_to_save)
                    with open(filepath, 'wb') as outfile:
                        pickle.dump(im_to_save, outfile)
                    
                    if maskfolder is not None:
                        save_mask(im_to_save, maskfolder, filename, mask_dil_struct, mask_target_dims)
    
    print('Saved ' + str(saved_count) + ' images from ' + str(len(stems_saved.keys())) + ' nuclei.')
    return np.array(nucs_saved)

#c = save_nucs(a, b, [], (34,100,100), '/Volumes/stad3/2022-05-23/batchtest', z_range=[-2, 0, 2], maskfolder='/Volumes/stad3/2022-05-23/batchtest_masks')

In [3]:
im_folder = '/Volumes/stad3/2022-05-26/bcd-stacks'
# h2b, zld-wt, zld-cterm, zld-mutzn5, cp190, rpb1, zld-mutzn5-het
sampleID = 'bcd'
datestr = '20220526'
box_dims = (34,100,100)
mask_dil_struct = np.ones((1,7,7))

im_outfolder = '/Volumes/stad3/real_nuclei_batch/nucs/bcd'
mask_outfolder = '/Volumes/stad3/real_nuclei_batch/masks/bcd'

In [6]:
nucs, filestems = extract_nuclei_batch(im_folder, sampleID, box_dims, datestr, mask_dil_struct, thresh_boost=300)

threshold: 2185
threshold: 1921
fuuuuck fixed that for you
oh fixed that for you
threshold: 2027
fuuuuck fixed that for you
threshold: 3076
oh fixed that for you
oh fixed that for you
Han, I got one!
oh fixed that for you
oh fixed that for you
fuuuuck fixed that for you
fuuuuck fixed that for you
threshold: 2925
oh fixed that for you
Han, I got one!
Han, I got one!
fuuuuck fixed that for you
oh fixed that for you
fuuuuck fixed that for you
fuuuuck fixed that for you
fuuuuck fixed that for you
oh fixed that for you
fuuuuck fixed that for you
oh fixed that for you
threshold: 2735
oh fixed that for you
fuuuuck fixed that for you
oh fixed that for you
fuuuuck fixed that for you
oh fixed that for you
fuuuuck fixed that for you
threshold: 2627
oh fixed that for you
fuuuuck fixed that for you
oh fixed that for you
fuuuuck fixed that for you
fuuuuck fixed that for you
threshold: 2922
threshold: 2527
Han, I got one!
threshold: 3095
threshold: 3019
threshold: 3322
oh fixed that for you
Han, I go

In [7]:
fm.viewer([nucs.max(axis=1), nucs.max(axis=2), nucs.max(axis=3)])

interactive(children=(Dropdown(description='Color', index=8, options=('inferno', 'plasma', 'gray_r', 'viridis'…

In [8]:
# Save good nuclei.
bad_nucs = [
    19,28,29,45,80,99,117,141,142,146,231,
]

z_range = [-2,0,2]
nucs_saved = save_nucs(nucs, filestems, bad_nucs, box_dims, im_outfolder, z_range=z_range, mask_dil_struct=mask_dil_struct, maskfolder=mask_outfolder)

Saved 777 images from 357 nuclei.


In [None]:
fm.viewer(nucs_saved.max(axis=1))

In [None]:
m = fm.load_pickle('/Volumes/stad3/real_nuclei_batch/masks/rpb1/rpb1_20220523-rpb1-gfp-em2-01_7_5.pkl')
fm.viewer(m)

In [19]:
# Count nuclei.

folder = im_outfolder

folder = '/Users/michaelstadler/Bioinformatics/Projects/rpb1/data/real_nuclei/set2/rpb1'

samples = {}

for f in os.listdir(folder):
    if f[0] == '.':
        continue
    splits = f.split('_')
    key = '_'.join(splits[1:3])
    sample = splits[0]

    if sample in samples:
        if key not in samples[sample]:
            samples[sample][key] = 1
    else:
        samples[sample] = {}
        samples[sample][key] = 1

for sample in samples:
    print(sample + ': ' + str(len(samples[sample].keys())))

rpb1-real: 610
rpb1: 7


In [17]:
556 + 446 + 1724 + 1775 + 373 + 399 + 814 + 1024

7111

In [None]:
fm.simnuc.make_mask_file('/Users/michaelstadler/Bioinformatics/Projects/rpb1/data/real_masks/rpb1-nc12', '/Users/michaelstadler/Bioinformatics/Projects/rpb1/data/real_masks/mask_files/nc12_1.pkl')

In [None]:
masks = fm.load_pickle('/Users/michaelstadler/Bioinformatics/Projects/rpb1/data/real_masks/mask_files/nc13_1.pkl')
fm.viewer(masks, 5)

In [None]:
sm = ndi.gaussian_filter(stack, (1,4,4))

In [None]:
plt.hist(sm.flatten(), bins=500);

In [None]:
skimage.filters.thresholding.thre

In [None]:
skimage.filters.thresholding.threshold_multiotsu(sm.flatten(), classes=3)[1]

In [None]:
?skimage.filters.thresholding.threshold_multiotsu