# Extracting nuclei from 3D image stacks

This notebook lets the user load 3D microscopy images, segment nuclei, extract a window around each nucleus, perform manual QC, and save the windows in a format suitable for downstream ML applications.

Output filenames contain three tags that serve different purposes in forming anchor–positive–negative (A-P-N) triplets:

sampleID: Negative images will have different sampleIDs (example: Zelda, Rpb1-nc12)
stackID: Positive images will have the same stackID (example: 20211115-zld-gfp-em1-03)
nucID: Positive images will have different nucIDs (for image augmentation purposes; allows multiple images from a single nucleus)

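Example filename (pattern `sampleID_stackID_nucID_z.pkl`):

zld_20211115-zld-gfp-em1-03_1_12.pkl

sampleID: zld
stackID: 20211115-zld-gfp-em1-03
nucID: 1 (index of the nucleus among those extracted from the stack)
z: 12 (first z-slice of the saved window; one nucleus may be saved at several z offsets)

Because filenames are split on `_` when counting nuclei, sampleID and stackID must not contain underscores.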

In [138]:
import flymovie as fm
from flymovie.simnuc import Sim
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
from importlib import reload
import scipy.ndimage as ndi
import skimage as ski
from importlib import reload
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [139]:
def extract_nuclei(stack, lmask, box_dims, mask_dil_struct=np.ones((2,10,10)), mask_nucs=True, centroid_margin=15):
    """Extract boxes containing nuclei from a labelmask.

    For every labeled object in `lmask`, a box of box_dims[1] x box_dims[2]
    pixels (last two axes) is cut from `stack` around the object's centroid,
    keeping the full extent of axis 0. If the box would extend past an image
    edge by at most `centroid_margin` pixels, it is shifted back inside;
    otherwise (or if the box cannot fit at all) the object is skipped.

    Args:
        stack: 3D image array. Assumed to be indexed (z, i, j) and to have the
            same shape as `lmask` -- TODO confirm for all callers.
        lmask: 3D integer label mask (0 = background).
        box_dims: (z, i, j) box size. Only box_dims[1] and box_dims[2] are used
            for cropping here; all slices along axis 0 are kept.
        mask_dil_struct: structuring element used to dilate each nucleus' own
            mask before masking. Not modified.
        mask_nucs: if True, pixels outside the dilated mask of the current
            nucleus are set to 0 (removes neighbouring nuclei from the box).
        centroid_margin: maximum number of pixels a box may extend past an
            image edge and still be rescued by shifting it inside.

    Returns:
        List of arrays of shape (stack.shape[0], box_dims[1], box_dims[2]), in
        the order returned by skimage.measure.regionprops.
    """
    def adjust_centroid(im_shape, box_dims, centroid, centroid_margin):
        """Shift centroid so its (i, j) box lies inside the image; None if impossible."""
        centroid_new = [centroid[0]]
        for dim in (1, 2):
            size = box_dims[dim]
            half = size // 2
            lo = centroid[dim] - half
            hi = lo + size
            if lo < 0:
                if -lo > centroid_margin:
                    return None
                print(f'Shifted box along axis {dim} to keep it inside the image.')
                center = half
            elif hi > im_shape[dim]:
                if hi - im_shape[dim] > centroid_margin:
                    return None
                print(f'Shifted box along axis {dim} to keep it inside the image.')
                center = im_shape[dim] - size + half
            else:
                center = centroid[dim]
            # A box wider than the image cannot fit: shifting one edge inside
            # pushes the other edge out, which would yield a truncated box.
            if center - half < 0 or center - half + size > im_shape[dim]:
                return None
            centroid_new.append(center)
        return centroid_new

    def get_box(arr, centroid):
        """Extract the box_dims[1] x box_dims[2] box around centroid, keeping all z."""
        slices = [slice(None)]
        for dim in (1, 2):
            # Use start + size (not centroid + half) so odd box sizes are exact.
            start = centroid[dim] - box_dims[dim] // 2
            slices.append(slice(start, start + box_dims[dim]))
        return arr[tuple(slices)]

    ims = []
    for region in ski.measure.regionprops(lmask):
        centroid = [int(round(x)) for x in region.centroid]
        centroid = adjust_centroid(lmask.shape, box_dims, centroid, centroid_margin)
        if centroid is None:
            continue

        im_box = get_box(stack, centroid)
        if mask_nucs:
            # Dilate this nucleus' own mask so the box keeps a rim around it
            # while zeroing out neighbouring nuclei.
            mask_box = ndi.binary_dilation(
                get_box(lmask, centroid) == region.label,
                structure=mask_dil_struct,
            )
            im_box = np.where(mask_box, im_box, 0)
        ims.append(im_box)

    return ims


def save_nucs(ims, bad_nucs, box_dims, sampleID, stackID, outfolder, z_range=(0,)):
    """Save QC-passed nuclei as pickles, optionally at several z start positions.

    For each image in `ims` whose index is not in `bad_nucs`, the z-extent of
    its positive voxels is found and a window of box_dims[0] slices centred on
    it is cut out. Each offset in `z_range` shifts that window's start slice, so
    one nucleus can yield several images. Only windows whose shape equals
    `box_dims` exactly are saved; all-zero images are skipped.

    Files are written to ``outfolder/{sampleID}_{stackID}_{nucID}_{z}.pkl``,
    where nucID is the index into `ims` and z is the window's first slice.

    Args:
        ims: sequence of 3D arrays (e.g. the output of extract_nuclei).
        bad_nucs: indices into `ims` to skip (rejected during manual QC).
        box_dims: (z, i, j) shape every saved window must have.
        sampleID: first filename tag.
        stackID: second filename tag.
        outfolder: existing directory to write into.
        z_range: offsets (in slices) added to the centred start slice.

    Returns:
        List of the saved arrays, in the order they were written.
    """
    def get_center_z(box):
        """Rounded midpoint of the z-extent of positive voxels, or None if there are none."""
        z_idx = np.nonzero(box > 0)[0]
        if z_idx.size == 0:
            return None
        # int() guarantees a Python int usable as a slice index.
        return int(round((int(z_idx.max()) + int(z_idx.min())) / 2))

    bad = set(bad_nucs)
    depth = box_dims[0]
    target_shape = tuple(box_dims)
    saved_count = 0
    nucs_saved = set()
    nucs = []
    for nucID, im in enumerate(ims):
        if nucID in bad:
            continue
        center_z = get_center_z(im)
        if center_z is None:
            # Nothing above zero (e.g. fully masked): no z-center to build a window on.
            continue
        start_z = center_z - depth // 2

        for z_add in z_range:
            z = start_z + z_add
            if z < 0:
                # A negative start would index from the end of the array.
                continue
            im_to_save = im[z:z + depth]
            if im_to_save.shape != target_shape:
                continue
            filepath = os.path.join(outfolder, '_'.join([sampleID, stackID, str(nucID), str(z)]) + '.pkl')
            with open(filepath, 'wb') as outfile:
                pickle.dump(im_to_save, outfile)
            saved_count += 1
            nucs_saved.add(nucID)
            nucs.append(im_to_save)

    print('Saved ' + str(saved_count) + ' images from ' + str(len(nucs_saved)) + ' nuclei.')
    return nucs

    

def save_masks(ims, mask_dil_struct, stackID, outfolder, target_dims=(100,100,100)):
    """Save binary masks derived from nucleus images to a folder as pickles.

    For each image, voxels > 0 form the mask, which is then eroded with
    `mask_dil_struct` (presumably to undo the dilation used when the nuclei
    were extracted -- confirm). Masks are written to
    ``outfolder/{stackID}_{n}.pkl`` with n counting from 1 in the order of
    `ims`; the mapping back to nucID / z offset is not preserved in the name.

    Args:
        ims: iterable of arrays (e.g. the list returned by save_nucs).
        mask_dil_struct: structuring element for the erosion.
        stackID: filename prefix.
        outfolder: existing directory to write into.
        target_dims: unused; kept for backward compatibility.
    """
    for count, im in enumerate(ims, start=1):
        mask = ndi.binary_erosion(im > 0, structure=mask_dil_struct)
        filepath = os.path.join(outfolder, stackID + '_' + str(count) + '.pkl')
        with open(filepath, 'wb') as outfile:
            pickle.dump(mask, outfile)


In [144]:
# Set files and variables.
# box_dims is the (z, i, j) size of each saved window. mask_dil_struct is the
# structuring element used to dilate nucleus masks in extract_nuclei and to
# erode them again in save_masks.
box_dims = (34,100,100)
mask_dil_struct = np.ones((1,7,7))
# sampleID is the first filename tag (see the filename convention at the top).
sampleID = 'rpb1'

# NOTE(review): hardcoded absolute paths on an external volume.
output_folder = '/Volumes/stad3/real_nuclei_set3/nucs/rpb1'
mask_folder = '/Volumes/stad3/real_nuclei_set3/masks/rpb1'

# NOTE(review): stackID (dated 20220523, "em3-02") does not obviously match
# stack_file (2022-03-21, "em2-16"). stackID is written into every output
# filename, so confirm it describes this stack before saving.
stackID = '20220523-rpb1-em3-02'
stack_file = '/Volumes/stad3/2022-03-21/26983-1-1-em2-16.czi'

In [145]:
# Read and view stack.
stack = fm.read_czi(stack_file, swapaxes=False)
# NOTE(review): a full slice is a no-op; presumably a placeholder for trimming
# leading entries along axis 0 -- remove or set explicit bounds.
stack = stack[0:]
fm.viewer(stack, 8)

interactive(children=(Dropdown(description='Color', index=9, options=('prism', 'Reds', 'inferno', 'plasma', 'v…

In [149]:
# Segment nuclei.
# Hand-set segmentation parameters for this stack; re-check thresh when
# switching stacks (the chosen value is echoed as "threshold: ...").
lmask = fm.segment_nuclei_3Dstack_rpb1(stack, usemax=False, sigma=4, min_nuc_center_dist=50, display=False, thresh=4800)
# NOTE(review): presumably keeps objects between 10,000 and 250,000 voxels --
# confirm against flymovie's labelmask_filter_objsize.
lmask = fm.labelmask_filter_objsize(lmask, 10_000, 2.5e5)
fm.viewer(lmask, 8)

threshold: 4800


interactive(children=(Dropdown(description='Color', index=9, options=('prism', 'Reds', 'inferno', 'plasma', 'v…

In [134]:
# Extract nuclei, view and identify bad nuclei.
ims = extract_nuclei(stack, lmask, box_dims, mask_dil_struct, mask_nucs=True, centroid_margin=12)
# Stack once: np.array copies every box, so avoid converting the list three times.
ims_array = np.array(ims)  # (n_nuclei, z, i, j)
# Max projections over z (axis 1) and over i (axis 2) for each nucleus; the
# index in the viewer is the index to put in bad_nucs.
fm.viewer([ims_array.max(axis=1), ims_array.max(axis=2)], 5)
fm.viewer([ims_array], 5)

interactive(children=(Dropdown(description='Color', index=9, options=('prism', 'Reds', 'inferno', 'plasma', 'v…

interactive(children=(Dropdown(description='Color', index=9, options=('prism', 'Reds', 'inferno', 'plasma', 'v…

In [None]:
# Save good nuclei.
# Fill bad_nucs with the indices (into ims) of nuclei rejected during QC above.
bad_nucs = []
# z_range shifts each window's start slice by -2, 0 and +2 around the nucleus
# center, saving up to three windows per nucleus.
nucs = save_nucs(ims, bad_nucs, box_dims, sampleID, stackID, output_folder, z_range=[-2,0,2])

In [None]:
# Save masks, if desired.
# One mask file per saved window in `nucs`, numbered in save order.
save_masks(nucs, mask_dil_struct, stackID, mask_folder)

In [None]:
# Count nuclei: distinct (stackID, nucID) pairs per sampleID among saved files.
# Filenames follow sampleID_stackID_nucID_z.pkl, so fields 1-2 identify one
# nucleus regardless of how many z offsets were saved for it.

folder = output_folder

samples = {}

for fname in os.listdir(folder):
    if fname.startswith('.'):
        continue  # skip hidden dotfiles
    parts = fname.split('_')
    samples.setdefault(parts[0], {})['_'.join(parts[1:3])] = 1

for sample, nuclei in samples.items():
    print(sample + ': ' + str(len(nuclei)))

In [None]:
# NOTE(review): hardcoded absolute user paths -- consider moving them to the config cell.
# The second path here ends in nc12_1.pkl while the next cell loads nc13_1.pkl; confirm intended.
fm.simnuc.make_mask_file('/Users/michaelstadler/Bioinformatics/Projects/rpb1/data/real_masks/rpb1-nc12', '/Users/michaelstadler/Bioinformatics/Projects/rpb1/data/real_masks/mask_files/nc12_1.pkl')

In [None]:
# Load a pickled mask file and inspect it.
# NOTE(review): hardcoded absolute user path; loads nc13_1.pkl, not the nc12_1.pkl named in the cell above.
masks = fm.load_pickle('/Users/michaelstadler/Bioinformatics/Projects/rpb1/data/real_masks/mask_files/nc13_1.pkl')
fm.viewer(masks, 5)

In [154]:
# Scratch calculation (purpose not documented): exp(-0.05) ** 45 == exp(-2.25) ~= 0.105.
np.exp(-0.05) ** 45

0.10539922456186437

In [None]:
# Smooth the stack (sigma 1 along axis 0, 4 along axes 1 and 2) before inspecting intensities.
# NOTE(review): gaussian_filter returns the input dtype by default, so an integer
# stack is smoothed with truncation -- pass output=float if that matters.
sm = ndi.gaussian_filter(stack, (1,4,4))

In [None]:
# Intensity distribution of the smoothed stack, to guide the segmentation threshold.
# ravel() avoids the full copy that flatten() makes.
fig, ax = plt.subplots()
ax.hist(sm.ravel(), bins=500)
ax.set(title='Intensity histogram of smoothed stack', xlabel='Intensity', ylabel='Voxel count');

In [None]:
# List the thresholding functions available in skimage.filters
# (`skimage` itself is not bound in this notebook -- it is imported as `ski`).
[name for name in dir(ski.filters) if name.startswith('threshold')]

In [None]:
# Upper of the two thresholds from a 3-class multi-Otsu on the smoothed stack
# (`skimage` is imported as `ski`; the bare name `skimage` is undefined here).
ski.filters.threshold_multiotsu(sm.ravel(), classes=3)[1]

In [None]:
?skimage.filters.thresholding.threshold_multiotsu