# Predict input data with CARE networks

Code from https://github.com/CSBDeep/CSBDeep/blob/master/examples/denoising3D/3_prediction.ipynb

## Predict "distance to nearest cell exterior"

In [None]:
import numpy as np
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread, imwrite
from csbdeep.models import CARE

In [None]:
modelname = 'care_bcm3d_target1_v2'

In [None]:
x = imread('training_data/full_semimanual-raw/test/images/im0.tif')
axes = 'ZYX'

In [None]:
model = CARE(config=None, name=modelname, basedir='models')
restored = model.predict(x, axes,n_tiles=(4, 4, 4))

In [None]:
imwrite('distance_to_nearest_cell_exterior.tif', restored, compression='zlib')

## Predict "proximity enhanced cell boundary"

-> Restart the kernel here to release the GPU memory held by the first model.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread, imwrite
from csbdeep.models import CARE

In [None]:
modelname = 'care_bcm3d_target2_v2'

In [None]:
x = imread('training_data/full_semimanual-raw/test/images/im0.tif')
axes = 'ZYX'

In [None]:
model = CARE(config=None, name=modelname, basedir='models')
restored = model.predict(x, axes,n_tiles=(4, 4, 4))

In [None]:
imwrite('proximity_enhanced_cell_boundary.tif', restored, compression='zlib')

# Classical post-processing

In [None]:
from tifffile import imread, imwrite
import matplotlib.pyplot as plt
import numpy as np

from skimage.filters import threshold_otsu
from skimage.measure import label, regionprops
from scipy.ndimage import binary_erosion, grey_dilation

In [None]:
# Reload both network predictions written by the prediction sections above.
target1, target2 = (
    imread(path)
    for path in (
        'distance_to_nearest_cell_exterior.tif',
        'proximity_enhanced_cell_boundary.tif',
    )
)

In [None]:
def normalize(img_stack: np.ndarray, low: float = 3, high: float = 99.9) -> np.ndarray:
    """Percentile-normalize an image stack.

    The ``low``-th percentile is mapped to 0 and the ``high``-th percentile
    to 1. Values outside that range are not clipped.

    Args:
        img_stack: Input image of any shape.
        low: Lower percentile in [0, 100].
        high: Upper percentile in [0, 100].

    Returns:
        Float array with the same shape as ``img_stack``. If both percentiles
        coincide (e.g. a constant image), an all-zero array is returned
        instead of the NaN/inf values a division by zero would produce.
    """
    img_stack = np.asarray(img_stack)
    p_low, p_high = np.percentile(img_stack, [low, high])
    scale = p_high - p_low
    if scale == 0:
        # Degenerate intensity range: 0/0 would give NaN everywhere, which
        # breaks the Otsu thresholding applied downstream.
        return np.zeros(img_stack.shape, dtype=np.float64)
    return (img_stack - p_low) / scale

## Post processing of "distance to nearest cell exterior"

In [None]:
# Central crop of z-slice 16 of the raw target1 prediction.
plt.imshow(target1[16][256:-256, 256:-256])

Predicted ‘distance to nearest cell exterior’ images were first normalized with a simple percentile-based method (by default, the 3rd percentile is mapped to 0 and the 99.9th percentile to 1).

In [None]:
# Percentile normalization with the default percentiles of `normalize`.
target1_normalized = normalize(img_stack=target1)

After applying Otsu thresholding to the ‘distance to nearest cell exterior’ image to obtain a binary image (Figure S3b), connected voxel clusters can be isolated and identified as single-cell objects by labeling connected regions.

In [None]:
# Global Otsu threshold on the normalized prediction -> foreground mask.
thresh = threshold_otsu(target1_normalized)
target1_binarized_ = np.greater(target1_normalized, thresh)

To split clusters that are only connected by one or two voxels, the boundary voxels of each object were set to zero (one binary erosion of the mask) before labeling connected regions.

In [None]:
# Strip one layer of boundary voxels (default structuring element, square
# connectivity 1) so thinly bridged objects fall apart. Then label the
# remaining connected components with full connectivity (connectivity=ndim
# is skimage's default).
target1_binarized = binary_erosion(target1_binarized_, iterations=1)
target1_labels = label(target1_binarized, connectivity=target1_binarized.ndim)

In [None]:
# Diagnostics for z-slice 16:
#   1) intensity histogram with the Otsu threshold,
#   2) binary mask before erosion,
#   3) binary mask after erosion,
#   4) resulting connected-component labels.
f, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4)
ax1.hist(target1_normalized.flatten(), 100, label='Normalized target1')
# axvline spans the full axis height. This avoids forcing a zero lower
# y-limit, which a log-scaled axis cannot display.
ax1.axvline(thresh, color='r', label='Otsu_threshold')
ax1.set_yscale('log')
ax1.legend()

ax2.imshow(target1_binarized_[16, 256:-256, 256:-256])
ax2.set_title('binarized')

ax3.imshow(target1_binarized[16, 256:-256, 256:-256])
ax3.set_title('eroded')

ax4.imshow(target1_labels[16, 256:-256, 256:-256])
ax4.set_title('labels')

After labeling, the erased boundary voxels were added back to each object by grey dilation of the label image.

In [None]:
target1_labels = grey_dilation(target1_labels, size=(2,2,2))

In [None]:
plt.imshow(target1_labels[16, 256:-256, 256:-256])

A conservative size-exclusion filter was applied: small objects with volume smaller than the radius cubed of the targeted cells were considered background noise and filtered out.

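* voxel size: 400 nm × 63 nm × 63 nm
* radius: 0.2775 µm (Hartmann *et al.*, **Nature Physics** (2019))

$r^3 = (0.2775\,\mu m)^3 \approx 0.0214\,\mu m^3$

$V_{vx} = 0.4\,\mu m \times 0.063\,\mu m \times 0.063\,\mu m = 0.0015876\,\mu m^3$

$V_{thresh} = r^3 / V_{vx} \approx 0.0214 / 0.0015876 \approx 13.5$ voxels

The code below uses the rounded value $V_{thresh} = 0.0213 / 0.0015 \approx 14.2$ voxels. Object volumes are integer voxel counts, so the only difference from 13.5 is that objects of exactly 14 voxels are also removed.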


In [None]:
props = regionprops(target1_labels)

In [None]:
volumes = np.array([p.area for p in props])
v_thresh = 14.2

In [None]:
f, ax1 = plt.subplots(1,1)
ax1.hist(volumes, 100, label='Target1 volumes');
y_lim = ax1.get_ylim()
ax1.plot([v_thresh, v_thresh], y_lim, 'r', label='volume threshold')
ax1.legend()

In [None]:
obj_labels = np.array([p.label for p in props])

In [None]:
lut = np.arange(np.max(target1_labels) + 1, dtype=target1_labels.dtype)

In [None]:
lut[obj_labels[volumes <= v_thresh]] = 0

In [None]:
target1_labels = lut[target1_labels]

In [None]:
# Sanity check
props = regionprops(target1_labels)
assert all(np.array([p.area for p in props]) > v_thresh)

## Post-processing with "proximity enhanced cell boundary"

In [None]:
from scipy.stats import iqr
from skimage.segmentation import watershed

Objects that need further processing were found by evaluating their volume and solidity, i.e., the ratio of volume to convex volume. Here, volume is defined as the number of voxels occupied by an object. Convex volume is defined as the number of voxels of the object's convex hull, the smallest convex polyhedron that encloses it. The upper limit was found with the interquartile rule: the upper limit is the third quartile (Q3) plus 1.5 times the interquartile range (IQR). If an object's volume or solidity is larger than the upper limit, it is singled out for further processing. Note that the code below deviates from this description. It uses a *lower* solidity limit (Q1 − 1.5 × IQR), and it flags an object only if its volume is above the volume limit **and** its solidity is below the solidity limit.

In [None]:
props = regionprops(target1_labels)

In [None]:
solidity = np.array([p.solidity for p in props])
volume = np.array([p.area for p in props])
labels = np.array([p.label for p in props])

In [None]:
v_thresh = np.percentile(volume, 75) + 1.5 * iqr(volume)
s_thresh = np.percentile(solidity, 25) - 1.5 * iqr(solidity)

In [None]:
f, (ax1, ax2) = plt.subplots(1,2)

ax1.hist(volume, 100, label='Volume distribution');
ax1.set_xlabel('volume')
ax1.set_ylabel('frequency')
ax1.plot([v_thresh, v_thresh], [0, 400], 'r', label='volume threshold')
ax1.legend()

ax2.hist(solidity, 100, label='Solidity distribution');
ax2.plot([s_thresh, s_thresh], [0, 400], 'r', label='volume threshold')
ax2.set_xlabel('solidity')
ax2.set_ylabel('frequency')
ax2.legend()

**For me it looks like the radius threshold is not high enough**

**Why are there so many 1.0 values in the solidity?**

**It does not make sense to me to apply the solidity threshold at the third quartile of volume / convex volume. It is probably better to apply it to the lower range (Q1 − 1.5 × IQR).**

In [None]:
undersegmented_labels = labels[(solidity < s_thresh) & (volume > v_thresh)]

In [None]:
len(undersegmented_labels)

In [None]:
lut = np.zeros(np.max(target1_labels) + 1,dtype=target1_labels.dtype)

In [None]:
lut[undersegmented_labels] = undersegmented_labels

In [None]:
target1_undersegmented = lut[target1_labels]

All these objects together generate a new binary image

In [None]:
# Binary mask of all objects selected for re-segmentation.
labels_filterd = np.greater(target1_undersegmented, 0)

CNN-produced ‘proximity enhanced cell boundary’ images were first normalized by the same percentile-based normalization method

In [None]:
# Same percentile normalization (default percentiles) as for target1.
target2_normalized = normalize(img_stack=target2)

Specifically, we generated a difference map by subtracting the ‘proximity enhanced cell boundary’ image from the ‘distance to nearest cell exterior’ image and then set all negative valued voxels to zero

In [None]:
# Difference map: normalized distance to the cell exterior minus normalized
# boundary proximity, with negative values clipped to zero.
factor = np.maximum(target1_normalized - target2_normalized, 0)

This difference map was then multiplied by the binary image generated in Step 1

In [None]:
labels_filtered = labels_filterd * factor

In [None]:
plt.imshow(labels_filtered[16, 256:-256, 256:-256])

In [None]:
thresh = filters.threshold_otsu(labels_filtered)

In [None]:
plt.hist(labels_filtered.flatten(), 100)
plt.plot([thresh, thresh], [0, 10_000], 'r')
plt.yscale('log')

In [None]:
watershed_seed = labels_filtered > thresh

In [None]:
plt.imshow(watershed_seed[16])

In [None]:
watershed_seed_labels = label(watershed_seed)

In [None]:
props = regionprops(watershed_seed_labels)

In [None]:
volumes = np.array([p.area for p in props])
seed_labels = np.array([p.label for p in props])

In [None]:
h = plt.hist(volumes, 100);

In [None]:
lut = np.arange(watershed_seed_labels.max() + 1, dtype=watershed_seed_labels.dtype)

In [None]:
lut[seed_labels[volumes < 30]] = 0

In [None]:
mask = lut[watershed_seed_labels] > 0

In [None]:
watershed_seed_ = watershed_seed.copy()
watershed_seed_[~mask] = 0 

In [None]:
props = regionprops(label(watershed_seed_ > thresh))

In [None]:
volume_ = np.array([p.area for p in props])

In [None]:
plt.hist(volume_, h[1]);

In [None]:
plt.imshow(watershed_seed_[16])

In [None]:
plt.imshow(labels_filtered[16, 512-128:512+256, 512-128:512+256])

In [None]:
ws_result = watershed(
    -labels_filtered,
    markers=label(watershed_seed_ > thresh),
    mask=target1_undersegmented > 0
)

In [None]:
plt.imshow(ws_result[16, 512-128:512+256, 512-128:512+256])

In [None]:
ws_props = regionprops(ws_result)

ws_solidity = np.array([p.solidity for p in ws_props])
ws_volume = np.array([p.area for p in ws_props])
ws_labels = np.array([p.label for p in ws_props])

In [None]:
plt.hist(ws_solidity, 100);
plt.plot([s_thresh, s_thresh], [0, 12], 'r')

In [None]:
plt.hist(ws_volume, 100);
plt.plot([v_thresh, v_thresh], [0, 12], 'r')

In [None]:
undersegmented_labels1 = ws_labels[(ws_solidity < s_thresh) & (ws_volume > v_thresh)]

In [None]:
undersegmented_labels1

### Multi-Otsu threshold watershed

In [None]:
from skimage.filters import threshold_multiotsu

In [None]:
_, _, thresh1, thresh2 = threshold_multiotsu(labels_filtered, 5)

In [None]:
plt.hist(labels_filtered.flatten(), 100);
plt.yscale('log')
plt.plot([thresh1, thresh1], [0, 10_000], 'r', label='thresh1')
plt.plot([thresh2, thresh2], [0, 10_000], 'b', label='thresh2')
plt.legend()

In [None]:
# What is the improvement so far?

In [None]:
watershed_seeds1 = labels_filtered > thresh1

In [None]:
watershed_seed_labels1 = label(watershed_seeds1)

In [None]:
# Delete again small watershed seeds

In [None]:
props1 = regionprops(watershed_seed_labels1)

In [None]:
volumes1 = np.array([p.area for p in props1]) 
labels1 = np.array([p.label for p in props1])

In [None]:
h1 = plt.hist(volumes1, 100)
plt.plot([30, 30], [0, 70], 'r')

In [None]:
np.sum(volumes1 < 30), len(volumes1)

In [None]:
# Delete seeds which are below the volume threshold

In [None]:
lut = np.arange(watershed_seed_labels.max()+1, dtype=watershed_seed_labels1.dtype)

In [None]:
lut[labels1[volumes1 < 30]] = 0

In [None]:
watershed_seed_labels = lut[watershed_seed_labels1]

In [None]:
lut = np.zeros(ws_result.max()+1, dtype=ws_result.dtype)
lut[undersegmented_labels1] = undersegmented_labels1

In [None]:
ws_results_ = lut[ws_result]

In [None]:
ws_result1 = watershed(
    -labels_filtered,
    markers=watershed_seed_labels,
    mask=ws_results_ > 0,
)

In [None]:
ws_props2 = regionprops(ws_result1)

ws_solidity2 = np.array([p.solidity for p in ws_props2])
ws_volume2 = np.array([p.area for p in ws_props2])
ws_labels2 = np.array([p.label for p in ws_props2])

In [None]:
undersegmented_labels2 = ws_labels2[(ws_solidity2 < s_thresh) & (ws_volume2 > v_thresh)]

In [None]:
undersegmented_labels2

In [None]:
watershed_seeds2 = labels_filtered > thresh2

In [None]:
watershed_seed_labels2 = label(watershed_seeds2)

In [None]:
props2 = regionprops(watershed_seed_labels2)

In [None]:
volumes2 = np.array([p.area for p in props2]) 
labels2 = np.array([p.label for p in props2])

In [None]:
lut = np.arange(watershed_seed_labels2.max()+1, dtype=watershed_seed_labels2.dtype)

In [None]:
lut[labels2[volumes2 < 30]] = 0

In [None]:
watershed_seed_labels2 = lut[watershed_seed_labels2]

In [None]:
lut = np.zeros(ws_result1.max()+1, dtype=ws_result1.dtype)
lut[undersegmented_labels2] = undersegmented_labels2

In [None]:
ws_results1_ = lut[ws_result1]

In [None]:
ws_result2 = watershed(
    -labels_filtered,
    markers=watershed_seed_labels2,
    mask=ws_results1_ > 0,
)

In [None]:
ws_props3 = regionprops(ws_result2)

ws_solidity3 = np.array([p.solidity for p in ws_props3])
ws_volume3 = np.array([p.area for p in ws_props3])
ws_labels3 = np.array([p.label for p in ws_props3])

In [None]:
undersegmented_labels3 = ws_labels3[(ws_solidity3 < s_thresh) & (ws_volume3 > v_thresh)]

In [None]:
undersegmented_labels3

# Combine the watershed segmentations 

In [None]:
from skimage.segmentation import relabel_sequential

`target1_labels`-> Results of the direct connected components

`ws_result` -> Result of the single otsu threshold watershed

`ws_result1` -> Result of the first 5-class Otsu threshold watershed

`ws_result2` -> Result of the second 5-class Otsu threshold watershed

In [None]:
# Sanity check: If the watershed mask was applied correctly, the watershed results do not overlap with the background

In [None]:
assert not any(np.unique(target1_labels[ws_result2 > 0]) == 0)

In [None]:
assert not any(np.unique(target1_labels[ws_result1 > 0]) == 0)

In [None]:
assert not any(np.unique(target1_labels[ws_result > 0]) == 0)

In [None]:
result_labels = target1_labels

A conservative size-exclusion filter was applied: objects with a volume below one tenth of the upper volume limit were considered unreasonably small fragments and filtered out.

In [None]:
from typing import Optional

from skimage.measure import regionprops
import numpy as np

def delete_small_objects(label_stack: np.ndarray, thresh: Optional[float] = None) -> np.ndarray:
    """Remove labelled objects whose voxel count is below ``thresh``.

    Args:
        label_stack: Non-negative integer label image (0 = background).
        thresh: Minimum object volume in voxels. Objects with fewer voxels
            are set to 0. If None, ``v_thresh / 10`` is used, where
            ``v_thresh`` is the notebook-level upper volume limit (looked
            up when the function is called).

    Returns:
        A label image with the same shape and dtype in which small objects
        are set to 0. Remaining objects keep their original label values.
    """
    if thresh is None:
        thresh = v_thresh / 10
    label_stack = np.asarray(label_stack)
    if label_stack.size == 0:
        return label_stack.copy()
    # volumes[k] is the voxel count of label k. Absent labels have count 0,
    # which is harmless because they never occur in label_stack.
    volumes = np.bincount(label_stack.ravel().astype(np.intp, copy=False))
    lut = np.arange(volumes.size, dtype=label_stack.dtype)
    lut[volumes < thresh] = 0
    return lut[label_stack]

In [None]:
# Remove fragments smaller than one tenth of the upper volume limit from each
# watershed result.
ws_result, ws_result1, ws_result2 = (
    delete_small_objects(segmentation)
    for segmentation in (ws_result, ws_result1, ws_result2)
)

In [None]:
# Replace every re-segmented region with its watershed labels, offset by the
# current maximum label so that the new labels never collide with existing
# ones. (Adding onto the old labels could map two different objects to the
# same value, since L1 + w1 == L2 + w2 is possible.) Later passes refine
# regions of earlier ones, so they are applied in order. Voxels that no pass
# labelled keep their previous label. relabel_sequential below compacts the
# resulting label values.
for ws in (ws_result, ws_result1, ws_result2):
    result_labels = np.where(ws > 0, ws + result_labels.max(), result_labels)

Since the ‘distance to nearest cell exterior’ predictions are confined to the cell interior, each object was enlarged by 1–2 voxels using standard morphological dilation. The next cell uses a 3 × 3 × 3 grey dilation, i.e. one voxel in every direction.

In [None]:
result_labels = grey_dilation(result_labels, size=(3,3,3))

In [None]:
result_labels, _, _ = relabel_sequential(result_labels)

In [None]:
imwrite('bcm3d_2.0.tif', result_labels, compression='zlib')