In [1]:
import sys
from pathlib import Path

# Add project root to path
project_root = Path().resolve().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Imports

In [2]:
import tifffile
import numpy as np
import matplotlib.pyplot as plt
import os

from histomicstk.preprocessing.color_normalization import (
    deconvolution_based_normalization,
)
from histomicstk.saliency.tissue_detection import get_tissue_mask
from skimage.transform import resize
from skimage.exposure import rescale_intensity
from stardist.models import StarDist2D

from tiatoolbox.tools import stainnorm
from typing import List, Tuple

bioimageio_utils.py (2): pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.

objc[45845]: Class GNotificationCenterDelegate is implemented in both /opt/anaconda3/envs/research_project/lib/libgio-2.0.0.dylib (0x31ae386d8) and /opt/anaconda3/envs/research_project/lib/python3.12/site-packages/openslide_bin/libopenslide.1.dylib (0x377ad9318). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.



# Setup

In [3]:
image_paths = sorted(
    [
        project_root / "data/images" / path
        for path in os.listdir(project_root / "data/images")
        if not path.startswith(".")
    ]
)
images = list(map(tifffile.imread, image_paths))

normalized_image_paths = sorted(
    [
        project_root / "data/normalized_images" / path
        for path in os.listdir(project_root / "data/normalized_images")
        if not path.startswith(".")
    ]
)
normalized_images = list(map(tifffile.imread, normalized_image_paths))

mask_paths = sorted(
    [
        project_root / "data/masks" / path
        for path in os.listdir(project_root / "data/masks")
        if not path.startswith(".")
    ]
)
masks = list(map(tifffile.imread, mask_paths))
data = list(zip(images, masks))

# Find Reference Image

In [4]:
def calculate_histogram_separation(
    image: np.ndarray, mask: np.ndarray, model: StarDist2D
) -> float:

    image_normed = rescale_intensity(image, out_range=(0, 1))
    labels, data_dict = model.predict_instances(
        image_normed, axes="YXC", prob_thresh=0.05, nms_thresh=0.3, return_labels=True
    )

    nuclei_mask = labels > 0
    tissue_mask = mask & ~nuclei_mask

    nuclei_intensities = image[nuclei_mask]
    tissue_intensities = image[tissue_mask]

    if len(nuclei_intensities) < 100 or len(tissue_intensities) < 100:
        return 0.0

    mean_nuclei = np.mean(nuclei_intensities)
    mean_tissue = np.mean(tissue_intensities)
    std_nuclei = np.std(nuclei_intensities)
    std_tissue = np.std(tissue_intensities)

    # Use between-class variance
    w_nuclei = len(nuclei_intensities) / (
        len(nuclei_intensities) + len(tissue_intensities)
    )
    w_tissue = 1 - w_nuclei

    separation_score = w_nuclei * w_tissue * (mean_nuclei - mean_tissue) ** 2

    return separation_score

In [None]:
def find_reference_image(
    data: List[Tuple[np.ndarray, np.ndarray]], model: StarDist2D
) -> Tuple[int, np.ndarray]:

    print("Analyzing histogram separation for all images...")

    max_separation_score = 0
    reference_idx = 0
    reference_image = images[0]

    for i, (image, mask) in enumerate(data):
        separation_score = calculate_histogram_separation(image, mask, model)
        if separation_score > max_separation_score:
            max_separation_score = separation_score
            reference_idx = i
            reference_image = image

    print(
        f"\nReference image selected: Image {reference_idx} with score {max_separation_score:.4f}"
    )

    return reference_idx, reference_image

In [12]:
model = StarDist2D.from_pretrained("2D_versatile_he")

Found model '2D_versatile_he' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.692478, nms_thresh=0.3.


In [19]:
ref_idx, ref_img = find_reference_image(data, model)

Analyzing histogram separation for all images...

Reference image selected: Image 9 with score 634.4435


In [5]:
# plt.imshow(ref_img)
# plt.axis(False)
# plt.show()

In [21]:
ref_path = image_paths[ref_idx]
ref_path

PosixPath('/Users/levin/Documents/Uni/Master/semester_3/research_project/ASON/data/images/E2+P4+DHT_1_M7_3L_0013.tif')

# Test Macenko and Vahadane Normalization with Reference Image

In [22]:
vah_norm = stainnorm.VahadaneNormalizer()
vah_norm.fit(ref_img)



In [23]:
mac_norm = stainnorm.MacenkoNormalizer()
mac_norm.fit(ref_img)

In [97]:
# for img in images:
#     img_norm_vah = vah_norm.transform(img.copy())
#     img_norm_mac = mac_norm.transform(img.copy())
#     compare_two_images(img, img_norm_vah, "Original Image", "Vahadane Normalized Image")
#     compare_two_images(img, img_norm_mac, "Original Image", "Macenko Normalized Image")

# Test Macenko Normalization for Segmentation

In [100]:
# for image, mask in data:
#     image_norm = mac_norm.transform(image.copy())
#     image_resc = rescale_intensity(image, out_range=(0, 1))
#     image_norm_resc = rescale_intensity(image_norm, out_range=(0, 1))
#     labels, _ = model.predict_instances(image_resc, axes='YXC', prob_thresh=0.05, nms_thresh=0.3, return_labels=True)
#     labels_norm, _ = model.predict_instances(image_norm_resc, axes='YXC', prob_thresh=0.05, nms_thresh=0.3, return_labels=True)
#     labels = cut_out_image(labels, mask)
#     labels_norm = cut_out_image(labels_norm, mask)
#     compare_two_images(render_label(labels, img=image_resc, cmap=(1.0, 1.0, 0), alpha=0.6), render_label(labels_norm, img=image_norm_resc, cmap=(1.0, 1.0, 0), alpha=0.6))

Macenko normalization does not improve the segmentation quality, in some cases it even reduces the performance. For now, do not apply macenko normalization but continue with the original images.

# Test Reinhard Normalization

In [87]:
reinhard_norm = stainnorm.ReinhardNormalizer()
reinhard_norm.fit(ref_img)

In [98]:
# for img in images:
#     img_norm = reinhard_norm.transform(img.copy())
#     compare_two_images(img, img_norm, "Original Image", "Reinhard Normalized Image")

In [99]:
# non_norm_num_objects = []
# norm_num_objects = []

# for image, mask in data:
#     image_norm = reinhard_norm.transform(image.copy())
#     image_resc = rescale_intensity(image, out_range=(0, 1))
#     image_norm_resc = rescale_intensity(image_norm, out_range=(0, 1))
#     labels, seg_dict = model.predict_instances(image_resc, axes='YXC', prob_thresh=0.05, nms_thresh=0.3, return_labels=True)
#     labels_norm, seg_dict_norm = model.predict_instances(image_norm_resc, axes='YXC', prob_thresh=0.05, nms_thresh=0.3, return_labels=True)
#     labels = cut_out_image(labels, mask)
#     labels_norm = cut_out_image(labels_norm, mask)
#     non_norm_num_objects.append(len(seg_dict['prob']))
#     norm_num_objects.append(len(seg_dict_norm['prob']))
#     compare_two_images(render_label(labels, img=image_resc, cmap=(1.0, 1.0, 0), alpha=0.6), render_label(labels_norm, img=image_norm_resc, cmap=(1.0, 1.0, 0), alpha=0.6))

In [93]:
print(
    f"Average number of detected nuclei in non-normalized images: {np.mean(non_norm_num_objects)}"
)
print(
    f"Average number of detected nuclei in normalized images: {np.mean(norm_num_objects)}"
)

Average number of detected nuclei in non-normalized images: 1005.85
Average number of detected nuclei in normalized images: 1179.6


Reihnard Normalization is overshooting the colors and contrast as well, so continue without normalization.

# --- OLD NORMALIZATION EXPERIMENTS ---

### Test Reference Image 1

In [None]:
# Only for visualization in presentation

# for i, img in enumerate(images):
#     img_norm_vah = stain_normalizer_2.transform(img.copy())
#     img_norm_mac = normalized_images[i]
#     plt.figure(figsize=(18, 6))

#     plt.subplot(1, 3, 1)
#     plt.axis(False)
#     plt.title("Original Image")
#     plt.imshow(img, cmap='gray')

#     plt.subplot(1, 3, 2)
#     plt.axis(False)
#     plt.title("Macenko Normalized Image")
#     plt.imshow(img_norm_mac)

#     plt.subplot(1, 3, 3)
#     plt.axis(False)
#     plt.title("Vahadane Normalized Image")
#     plt.imshow(img_norm_vah)

#     plt.show()

In [29]:
# for img in images:
#     img_norm = stain_normalizer_1.transform(img.copy())
#     compare_two_images(img, img_norm, "Original Image", "Normalized Image (Ref1)")

##### Compare segmentation performance with non-normalized and Ref1-normalization

In [16]:
from stardist.models import StarDist2D

In [13]:
model = StarDist2D.from_pretrained("2D_versatile_he")

Found model '2D_versatile_he' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.692478, nms_thresh=0.3.


In [14]:
mask_paths = sorted(
    [
        "data/masks/" + path
        for path in os.listdir("data/masks")
        if not path.startswith(".")
    ]
)
masks = list(map(tifffile.imread, mask_paths))
data = list(zip(images, masks))

In [30]:
# for image, mask in data:
#     image_norm = stain_normalizer_1.transform(image.copy())
#     image_resc = rescale_intensity(image, out_range=(0, 1))
#     image_norm_resc = rescale_intensity(image_norm, out_range=(0, 1))
#     labels, _ = model.predict_instances(image_resc, axes='YXC', prob_thresh=0.05, nms_thresh=0.3, return_labels=True)
#     labels_norm, _ = model.predict_instances(image_norm_resc, axes='YXC', prob_thresh=0.05, nms_thresh=0.3, return_labels=True)
#     labels = cut_out_image(labels, mask)
#     labels_norm = cut_out_image(labels_norm, mask)
#     compare_two_images(render_label(labels, img=image_resc, cmap=(1.0, 1.0, 0), alpha=0.6), render_label(labels_norm, img=image_norm_resc, cmap=(1.0, 1.0, 0), alpha=0.6))

### Test Reference Image 2

In [32]:
# for img in images:
#     img_norm = stain_normalizer_2.transform(img.copy())
#     compare_two_images(img, img_norm, "Original Image", "Normalized Image (Ref2)")

In [31]:
# for image, mask in data:
#     image_norm = stain_normalizer_2.transform(image.copy())
#     image_resc = rescale_intensity(image, out_range=(0, 1))
#     image_norm_resc = rescale_intensity(image_norm, out_range=(0, 1))
#     labels, _ = model.predict_instances(image_resc, axes='YXC', prob_thresh=0.05, nms_thresh=0.3, return_labels=True)
#     labels_norm, _ = model.predict_instances(image_norm_resc, axes='YXC', prob_thresh=0.05, nms_thresh=0.3, return_labels=True)
#     labels = cut_out_image(labels, mask)
#     labels_norm = cut_out_image(labels_norm, mask)
#     compare_two_images(render_label(labels, img=image_resc, cmap=(1.0, 1.0, 0), alpha=0.6), render_label(labels_norm, img=image_norm_resc, cmap=(1.0, 1.0, 0), alpha=0.6))

In [29]:
def create_mask(img):
    mask_out, _ = get_tissue_mask(
        img, deconvolve_first=True, n_thresholding_steps=1, sigma=1.5, min_size=30
    )

    mask_out_fixed = (
        resize(mask_out == 0, output_shape=img.shape[:2], order=0, preserve_range=True)
        == 1
    )

    return ~mask_out_fixed

In [20]:
reference_mask_1 = create_mask(reference_image_1)
reference_mask_2 = create_mask(reference_image_2)

In [38]:
# cut_out_reference_1 = cut_out_image(reference_image_1, reference_mask_1)
# cut_out_reference_2 = cut_out_image(reference_image_2, reference_mask_2)
# compare_two_images(cut_out_reference_1, cut_out_reference_2)

In [22]:
def color_normalization_macenko(
    target_img,
    reference_img,
    mask=None,
    stains=["hematoxylin", "eosin"],
    stain_unmixing_method="macenko_pca",
):
    stain_unmixing_routine_params = {
        "stains": stains,
        "stain_unmixing_method": stain_unmixing_method,
    }

    normalized_target_img = deconvolution_based_normalization(
        target_img,
        im_target=reference_img,
        mask_out=mask,
        stain_unmixing_routine_params=stain_unmixing_routine_params,
    )

    return normalized_target_img

In [32]:
masks = [create_mask(img) for img in images]

In [24]:
# compare_two_images(images[5], masks[5])

### Test Reference Image 1 without cutout

In [79]:
normalized_images_1 = [
    color_normalization_macenko(image, reference_image_1, masks[i])
    for i, image in enumerate(images)
]

In [92]:
# for img, norm_img in zip(images, normalized_images_1):
#     plt.figure(figsize=(12, 6))

#     plt.subplot(1, 2, 1)
#     plt.axis(False)
#     plt.title("Original image")
#     plt.imshow(img)

#     plt.subplot(1, 2, 2)
#     plt.axis(False)
#     plt.title("Normalized image")
#     plt.imshow(norm_img)

#     plt.show()

### Test reference image 2 without cutout

In [35]:
normalized_images_2 = [
    color_normalization_macenko(image, reference_image_2, masks[i])
    for i, image in enumerate(images)
]

In [37]:
# for img, norm_img in zip(images, normalized_images_2):
#     plt.figure(figsize=(12, 6))

#     plt.subplot(1, 2, 1)
#     plt.axis(False)
#     plt.title("Original image")
#     plt.imshow(img)

#     plt.subplot(1, 2, 2)
#     plt.axis(False)
#     plt.title("Normalized image")
#     plt.imshow(norm_img)

#     plt.show()

### Test reference image 1 with cut out

In [83]:
normalized_images_1_cut = [
    color_normalization_macenko(image, cut_out_reference_1, masks[i])
    for i, image in enumerate(images)
]

In [94]:
# for img, norm_img in zip(images, normalized_images_1_cut):
#     plt.figure(figsize=(12, 6))

#     plt.subplot(1, 2, 1)
#     plt.axis(False)
#     plt.title("Original image")
#     plt.imshow(img)

#     plt.subplot(1, 2, 2)
#     plt.axis(False)
#     plt.title("Normalized image")
#     plt.imshow(norm_img)

#     plt.show()

### Test reference 2 with cutout

In [85]:
normalized_images_2_cut = [
    color_normalization_macenko(image, cut_out_reference_2, masks[i])
    for i, image in enumerate(images)
]

In [95]:
# for img, norm_img in zip(images, normalized_images_2_cut):
#     plt.figure(figsize=(12, 6))

#     plt.subplot(1, 2, 1)
#     plt.axis(False)
#     plt.title("Original image")
#     plt.imshow(img)

#     plt.subplot(1, 2, 2)
#     plt.axis(False)
#     plt.title("Normalized image")
#     plt.imshow(norm_img)

#     plt.show()

Reference image 2 without cutout seems to provide the best results!