Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stain normalization with Reinhard's method #409

Merged
merged 9 commits into from Jan 14, 2023
17 changes: 11 additions & 6 deletions histolab/filters/image_filters_functional.py
Expand Up @@ -583,13 +583,18 @@ def rgb_to_lab(
np.ndarray
Array representation of the image in LAB space

Raises
------
Exception
If the image mode is not RGB
"""
if img.mode != "RGB":
raise Exception("Input image must be RGB")
if img.mode == "L":
raise ValueError("Input image must be RGB or RGBA")

if img.mode == "RGBA":
red, green, blue, _ = img.split()
img = PIL.Image.merge("RGB", (red, green, blue))

warn(
"Input image must be RGB. "
"NOTE: the image will be converted to RGB before OD conversion."
)
img_arr = np.array(img)
lab_arr = sk_color.rgb2lab(img_arr, illuminant=illuminant, observer=observer)
return lab_arr
Expand Down
107 changes: 105 additions & 2 deletions histolab/stain_normalizer.py
Expand Up @@ -22,12 +22,12 @@
# and torchstain https://github.com/EIDOSlab/torchstain


from typing import List
from typing import List, Tuple

import numpy as np
import PIL

from .filters.image_filters import RgbToOd
from .filters.image_filters import LabToRgb, RgbToLab, RgbToOd
from .masks import TissueMask
from .mixins import LinalgMixin
from .tile import Tile
Expand Down Expand Up @@ -318,3 +318,106 @@ def _ordered_stack(stain_matrix: np.ndarray, order: List[int]) -> np.ndarray:
return np.stack([stain_matrix[..., j] for j in order], -1)

return _ordered_stack(stain_matrix, _get_channel_order(stain_matrix))


class ReinhardStainNormalizer:
"""Stain normalizer using the method of E. Reinhard et al. [1]_

References
----------
.. [1] Reinhard, Erik, et al. "Color transfer between images." IEEE Computer
graphics and applications 21.5 (2001): 34-41.

"""

def __init__(self):
self.target_means = None
self.target_stds = None

def fit(self, target_rgb: PIL.Image.Image) -> None:
"""Fit stain normalizer using ``target_img``.

Parameters
----------
target_rgb : PIL.Image.Image
Target image for stain normalization. Can be either RGB or RGBA.
"""
means, stds = self._summary_statistics(target_rgb)
self.target_means = means
self.target_stds = stds

def transform(self, img_rgb: PIL.Image.Image) -> PIL.Image.Image:
"""Normalize staining of ``img_rgb``.
alessiamarcolini marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
img_rgb : PIL.Image.Image
Image to normalize. Can be either RGB or RGBA.

Returns
-------
PIL.Image.Image
Image with normalized stain.
"""
means, stds = self._summary_statistics(img_rgb)
img_lab = RgbToLab()(img_rgb)

mask = self._tissue_mask(img_rgb)
masked_img_lab = np.ma.masked_array(img_lab, ~mask)

norm_lab = (
((masked_img_lab - means) * (self.target_stds / stds)) + self.target_means
).data

for i in range(3):
original = img_lab[:, :, i].copy()
new = norm_lab[:, :, i].copy()
original[np.not_equal(~mask[:, :, 0], True)] = 0
new[~mask[:, :, 0]] = 0
norm_lab[:, :, i] = new + original

norm_rgb = LabToRgb()(norm_lab)
return norm_rgb

def _summary_statistics(
self, img_rgb: PIL.Image.Image
) -> Tuple[np.ndarray, np.ndarray]:
"""Return mean and standard deviation of each channel in LAB color space.
alessiamarcolini marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
img_rgb : PIL.Image.Image
Input image.

Returns
-------
np.ndarray
Mean of each channel in LAB color space.
np.ndarray
Standard deviation of each channel in LAB color space.
"""
mask = self._tissue_mask(img_rgb)

img_lab = RgbToLab()(img_rgb)
mean_per_channel = img_lab.mean(axis=(0, 1), where=mask)
std_per_channel = img_lab.std(axis=(0, 1), where=mask)
return mean_per_channel, std_per_channel

@staticmethod
def _tissue_mask(img_rgb: PIL.Image.Image) -> np.ndarray:
"""Return a binary mask of the tissue in ``img_rgb``.

Parameters
----------
img_rgb : PIL.Image.Image
Input image. Can be either RGB or RGBA.

Returns
-------
np.ndarray
Binary tissue mask,
alessiamarcolini marked this conversation as resolved.
Show resolved Hide resolved
"""
tile = Tile(img_rgb, None)
mask = tile.tissue_mask
mask = np.dstack((mask, mask, mask))
return mask
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
22 changes: 11 additions & 11 deletions tests/integration/test_image_filters.py
Expand Up @@ -1267,26 +1267,26 @@ def test_yen_threshold_filter_on_gs_image():
np.testing.assert_array_equal(yen_threshold_mask, expected_value)


def test_rgb_to_lab_filter_with_rgb_image():
@pytest.mark.parametrize(
"pil_image",
(RGBA.DIAGNOSTIC_SLIDE_THUMB, RGB.DIAGNOSTIC_SLIDE_THUMB_RGB),
)
def test_rgb_to_lab_filter_with_rgb_and_rgba_image(pil_image):
expected_value = load_expectation(
"arrays/diagnostic-slide-thumb-rgb-to-lab", type_="npy"
)

lab_img = imf.rgb_to_lab(RGB.DIAGNOSTIC_SLIDE_THUMB_RGB)
lab_img = imf.rgb_to_lab(pil_image)

np.testing.assert_array_almost_equal(lab_img, expected_value)


@pytest.mark.parametrize(
"pil_image",
(RGBA.DIAGNOSTIC_SLIDE_THUMB, GS.DIAGNOSTIC_SLIDE_THUMB_GS),
)
def test_rgb_to_lab_raises_exception_on_gs_and_rgba_image(pil_image):
with pytest.raises(Exception) as err:
imf.rgb_to_lab(pil_image)
def test_rgb_to_lab_raises_exception_on_gs_image():
with pytest.raises(ValueError) as err:
imf.rgb_to_lab(GS.DIAGNOSTIC_SLIDE_THUMB_GS)

assert isinstance(err.value, Exception)
assert str(err.value) == "Input image must be RGB"
assert isinstance(err.value, ValueError)
assert str(err.value) == "Input image must be RGB or RGBA"


def test_dab_channel_filter_with_rgb_image():
Expand Down
98 changes: 97 additions & 1 deletion tests/integration/test_stain_normalizer.py
@@ -1,7 +1,7 @@
import numpy as np
import pytest

from histolab.stain_normalizer import MacenkoStainNormalizer
from histolab.stain_normalizer import MacenkoStainNormalizer, ReinhardStainNormalizer

from ..fixtures import TILES
from ..util import load_expectation
Expand Down Expand Up @@ -135,3 +135,99 @@ def it_knows_how_to_fit_and_transform(
np.testing.assert_array_almost_equal(
np.array(img_normalized), np.array(expected_img_normalized)
)


class Describe_ReinhardStainNormalizer:
@pytest.mark.parametrize(
"img, expected_mean, expected_std",
[
(
TILES.TISSUE_LEVEL0_7352_11762_7864_12274,
np.array([57.77235476, 30.6667066, -12.3231239]),
np.array([15.37156544, 10.19152132, 5.33366456]),
),
(
TILES.MEDIUM_NUCLEI_SCORE_LEVEL1,
np.array([63.60525855, 34.10860319, -3.26439523]),
np.array([13.75089992, 13.89511025, 6.91485312]),
),
(
TILES.LOW_NUCLEI_SCORE_LEVEL0,
np.array([82.67172754, 8.5236446, -4.30401803]),
np.array([12.48770309, 8.40158118, 10.3241243]),
),
],
)
def it_knows_its_summary_statistics(self, img, expected_mean, expected_std):
normalizer = ReinhardStainNormalizer()

mean, std = normalizer._summary_statistics(img)

np.testing.assert_almost_equal(mean, expected_mean)
np.testing.assert_almost_equal(std, expected_std)

@pytest.mark.parametrize(
"img, expected_mean, expected_std",
[
(
TILES.TISSUE_LEVEL0_7352_11762_7864_12274,
np.array([57.77235476, 30.6667066, -12.3231239]),
np.array([15.37156544, 10.19152132, 5.33366456]),
),
(
TILES.MEDIUM_NUCLEI_SCORE_LEVEL1,
np.array([63.60525855, 34.10860319, -3.26439523]),
np.array([13.75089992, 13.89511025, 6.91485312]),
),
(
TILES.LOW_NUCLEI_SCORE_LEVEL0,
np.array([82.67172754, 8.5236446, -4.30401803]),
np.array([12.48770309, 8.40158118, 10.3241243]),
),
],
)
def it_knows_how_to_fit(self, img, expected_mean, expected_std):
normalizer = ReinhardStainNormalizer()

assert normalizer.target_means is None
assert normalizer.target_stds is None

normalizer.fit(img)

target_means = normalizer.target_means
target_stds = normalizer.target_stds

np.testing.assert_almost_equal(target_means, expected_mean)
np.testing.assert_almost_equal(target_stds, expected_std)

@pytest.mark.parametrize(
"img_to_fit, img_to_transform, expected_img_normalized_path",
[
(
TILES.TISSUE_LEVEL0_7352_11762_7864_12274,
TILES.MEDIUM_NUCLEI_SCORE_LEVEL1,
"pil-images-rgb/tissue-level0-7352-11762-7864-12274"
"--medium-nuclei-score-level1--reinhard",
),
(
TILES.TISSUE_LEVEL0_7352_11762_7864_12274,
TILES.LOW_NUCLEI_SCORE_LEVEL0,
"pil-images-rgb/tissue-level0-7352-11762-7864-12274"
"--low-nuclei-score-level0--reinhard",
),
],
)
def it_knows_how_to_fit_and_transform_Re(
self, img_to_fit, img_to_transform, expected_img_normalized_path
):
expected_img_normalized = load_expectation(
expected_img_normalized_path, type_="png"
)
normalizer = ReinhardStainNormalizer()

normalizer.fit(img_to_fit)
img_normalized = normalizer.transform(img_to_transform)

np.testing.assert_almost_equal(
np.array(img_normalized), np.array(expected_img_normalized)
)