diff --git a/src/eva/vision/data/wsi/backends/base.py b/src/eva/vision/data/wsi/backends/base.py index 66c6f323..830a213b 100644 --- a/src/eva/vision/data/wsi/backends/base.py +++ b/src/eva/vision/data/wsi/backends/base.py @@ -38,7 +38,7 @@ def level_downsamples(self) -> Sequence[float]: @property @abc.abstractmethod def mpp(self) -> float: - """Microns per pixel at the highest resolution.""" + """Microns per pixel at the highest resolution (level 0).""" @abc.abstractmethod def read_region( @@ -47,9 +47,10 @@ def read_region( """Reads and returns image data for a specified region and zoom level. Args: - location: Top-left corner (x, y) to start reading. - size: Region size as (width, height), relative to . - level: Zoom level, with 0 being the highest resolution. + location: Top-left corner (x, y) to start reading at level 0. + level: WSI level to read from. + size: Region size as (width, height) in pixels at the selected read level. + Remember to scale the size correctly. """ def get_closest_level(self, target_mpp: float) -> int: diff --git a/src/eva/vision/data/wsi/backends/openslide.py b/src/eva/vision/data/wsi/backends/openslide.py index dbfb4eea..4173b8cf 100644 --- a/src/eva/vision/data/wsi/backends/openslide.py +++ b/src/eva/vision/data/wsi/backends/openslide.py @@ -12,11 +12,11 @@ class WsiOpenslide(base.Wsi): """Class for loading data from WSI files using the OpenSlide library.""" - _wsi: openslide.OpenSlide | openslide.ImageSlide + _wsi: openslide.OpenSlide @override - def open_file(self, file_path: str) -> openslide.OpenSlide | openslide.ImageSlide: - return openslide.open_slide(file_path) + def open_file(self, file_path: str) -> openslide.OpenSlide: + return openslide.OpenSlide(file_path) @property @override @@ -40,8 +40,21 @@ def mpp(self) -> float: def read_region( self, location: Tuple[int, int], level: int, size: Tuple[int, int] ) -> np.ndarray: - x_max, y_max = self._wsi.level_dimensions[level] - if location[0] + size[0] > x_max or location[1] + size[1] > y_max: + x_max, y_max = self.level_dimensions[0] + + x_scale = x_max / self._wsi.level_dimensions[level][0] + y_scale = y_max / self._wsi.level_dimensions[level][1] + + if ( + int(location[0] + x_scale * size[0]) > x_max + or int(location[1] + y_scale * size[1]) > y_max + ): raise ValueError(f"Out of bounds region: {location}, {size}, {level}") - data = self._wsi.read_region(location, level, size) - return np.array(data.convert("RGB")) + + data = np.array(self._wsi.read_region(location, level, size)) + + if data.shape[2] == 4: + # Change color to white where the alpha channel is 0 + data[data[:, :, 3] == 0] = 255 + + return data[:, :, :3] diff --git a/src/eva/vision/data/wsi/patching/coordinates.py b/src/eva/vision/data/wsi/patching/coordinates.py index f38344b9..0600db98 100644 --- a/src/eva/vision/data/wsi/patching/coordinates.py +++ b/src/eva/vision/data/wsi/patching/coordinates.py @@ -6,7 +6,7 @@ from eva.vision.data.wsi import backends from eva.vision.data.wsi.patching import samplers -from eva.vision.utils.mask import get_mask +from eva.vision.data.wsi.patching.mask import Mask, get_mask, get_mask_level LRU_CACHE_SIZE = 32 @@ -16,16 +16,18 @@ class PatchCoordinates: """A class to store coordinates of patches from a whole-slide image. Args: - x_y: A list of (x, y) coordinates of the patches. - width: The width of the patches, in pixels (refers to x-dim). - height: The height of the patches, in pixels (refers to y-dim). - level_idx: The level index of the patches. + x_y: A list of (x, y) coordinates of the patches (refer to level 0). + width: The width of the patches, in pixels (refers to level_idx). + height: The height of the patches, in pixels (refers to level_idx). + level_idx: The level index at which to extract the patches. + mask: The foreground mask of the wsi. """ x_y: List[Tuple[int, int]] width: int height: int level_idx: int + mask: Mask | None = None @classmethod def from_file( @@ -50,24 +52,26 @@ def from_file( backend: The backend to use for reading the whole-slide images. """ wsi = backends.wsi_backend(backend)(wsi_path) - level_idx = wsi.get_closest_level(target_mpp) - level_mpp = wsi.mpp * wsi.level_downsamples[level_idx] - mpp_ratio = target_mpp / level_mpp - scaled_width, scaled_height = int(mpp_ratio * width), int(mpp_ratio * height) + # Sample patch coordinates at level 0 + mpp_ratio_0 = target_mpp / wsi.mpp sample_args = { - "width": scaled_width, - "height": scaled_height, - "layer_shape": wsi.level_dimensions[level_idx], + "width": int(mpp_ratio_0 * width), + "height": int(mpp_ratio_0 * height), + "layer_shape": wsi.level_dimensions[0], } if isinstance(sampler, samplers.ForegroundSampler): - sample_args["mask"] = get_mask(wsi, level_idx) + mask_level_idx = get_mask_level(wsi, width, height, target_mpp) + sample_args["mask"] = get_mask(wsi, mask_level_idx) + + x_y = list(sampler.sample(**sample_args)) - x_y = [] - for x, y in sampler.sample(**sample_args): - x_y.append((x, y)) + # Scale dimensions to level that is closest to the target_mpp + level_idx = wsi.get_closest_level(target_mpp) + mpp_ratio = target_mpp / (wsi.mpp * wsi.level_downsamples[level_idx]) + scaled_width, scaled_height = int(mpp_ratio * width), int(mpp_ratio * height) - return cls(x_y, scaled_width, scaled_height, level_idx) + return cls(x_y, scaled_width, scaled_height, level_idx, sample_args.get("mask")) @functools.lru_cache(LRU_CACHE_SIZE) diff --git a/src/eva/vision/data/wsi/patching/mask.py b/src/eva/vision/data/wsi/patching/mask.py new file mode 100644 index 00000000..3dc1d9bb --- /dev/null +++ b/src/eva/vision/data/wsi/patching/mask.py @@ -0,0 +1,98 @@ +"""Functions for extracting foreground masks.""" + +import dataclasses +from typing import Tuple + +import cv2 +import numpy as np + +from eva.vision.data.wsi.backends.base import Wsi + + +@dataclasses.dataclass +class Mask: + """A class to store the mask of a whole-slide image.""" + + mask_array: np.ndarray + """Binary mask array where 1s represent the foreground and 0s represent the background.""" + + mask_level_idx: int + """WSI level index at which the mask_array was extracted.""" + + scale_factors: Tuple[float, float] + """Factors to scale x/y coordinates from mask_level_idx to level 0.""" + + +def get_mask( + wsi: Wsi, + mask_level_idx: int, + kernel_size: Tuple[int, int] = (7, 7), + gray_threshold: int = 220, + fill_holes: bool = False, +) -> Mask: + """Extracts a binary mask from an image. + + Args: + wsi: The WSI object. + mask_level_idx: The level index of the WSI at which we want to extract the mask. + kernel_size: The size of the kernel for morphological operations. + gray_threshold: The threshold for the gray scale image. + fill_holes: Whether to fill holes in the mask. + """ + image = wsi.read_region((0, 0), mask_level_idx, wsi.level_dimensions[mask_level_idx]) + + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_size) + gray = np.array(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY), dtype=np.uint8) + mask_array = np.where(gray < gray_threshold, 1, 0).astype(np.uint8) + + if fill_holes: + mask_array = cv2.dilate(mask_array, kernel, iterations=1) + contour, _ = cv2.findContours(mask_array, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE) + for cnt in contour: + cv2.drawContours(mask_array, [cnt], 0, (1,), -1) + + scale_factors = ( + wsi.level_dimensions[0][0] / wsi.level_dimensions[mask_level_idx][0], + wsi.level_dimensions[0][1] / wsi.level_dimensions[mask_level_idx][1], + ) + + return Mask(mask_array=mask_array, mask_level_idx=mask_level_idx, scale_factors=scale_factors) + + +def get_mask_level( + wsi: Wsi, + width: int, + height: int, + target_mpp: float, + min_mask_patch_pixels: int = 3 * 3, +) -> int: + """For performance reasons, we generate the mask at the lowest resolution level possible. + + However, if minimum resolution level has too few pixels, the patches scaled to that level will + be too small or even collapse to a single pixel. This function allows to find the lowest + resolution level that yields mask patches with at least `min_mask_patch_pixels` pixels. + + Args: + wsi: The WSI object. + width: The width of the patches to be extracted, in pixels (at target_mpp). + height: The height of the patches to be extracted, in pixels. + target_mpp: The target microns per pixel (mpp) for the patches. + min_mask_patch_pixels: The minimum number of pixels required for the mask patches. + Mask patch refers to width / height at target_mpp scaled down to the WSI level + at which the mask is generated. + """ + level_mpps = wsi.mpp * np.array(wsi.level_downsamples) + mask_level_idx = None + + for level_idx, level_mpp in reversed(list(enumerate(level_mpps))): + mpp_ratio = target_mpp / level_mpp + scaled_width, scaled_height = int(mpp_ratio * width), int(mpp_ratio * height) + + if scaled_width * scaled_height >= min_mask_patch_pixels: + mask_level_idx = level_idx + break + + if mask_level_idx is None: + raise ValueError("No level with the specified minimum number of patch pixels available.") + + return mask_level_idx diff --git a/src/eva/vision/data/wsi/patching/samplers.py b/src/eva/vision/data/wsi/patching/samplers.py index ed5f5966..58df27fd 100644 --- a/src/eva/vision/data/wsi/patching/samplers.py +++ b/src/eva/vision/data/wsi/patching/samplers.py @@ -6,6 +6,8 @@ import numpy as np +from eva.vision.data.wsi.patching.mask import Mask + class Sampler(abc.ABC): """Base class for samplers.""" @@ -16,7 +18,7 @@ def sample( width: int, height: int, layer_shape: Tuple[int, int], - mask: Tuple[np.ndarray, float] | None = None, + mask: Mask | None = None, ) -> Generator[Tuple[int, int], None, None]: """Sample patche coordinates. @@ -39,7 +41,7 @@ class ForegroundSampler(Sampler): @abc.abstractmethod def is_foreground( self, - mask: Tuple[np.ndarray, float], + mask: Mask, x: int, y: int, width: int, @@ -150,7 +152,7 @@ def sample( width: int, height: int, layer_shape: Tuple[int, int], - mask: Tuple[np.ndarray, float], + mask: Mask, ): """Sample patches from a grid containing foreground. @@ -174,7 +176,7 @@ def sample( def is_foreground( self, - mask: Tuple[np.ndarray, float], + mask: Mask, x: int, y: int, width: int, @@ -191,14 +193,13 @@ def is_foreground( height: The height of the patch. min_foreground_ratio: The minimum amount of foreground in the patch. """ - mask_array, mask_scale_factor = mask - x_, y_, width_, height_ = self._scale_coords(mask_scale_factor, x, y, width, height) - patch_mask = mask_array[y_ : y_ + height_, x_ : x_ + width_] - # TODO: look into warning "RuntimeWarning: invalid value encountered in divide" + x_, y_ = self._scale_coords(x, y, mask.scale_factors) + width_, height_ = self._scale_coords(width, height, mask.scale_factors) + patch_mask = mask.mask_array[y_ : y_ + height_, x_ : x_ + width_] return patch_mask.sum() / patch_mask.size > min_foreground_ratio - def _scale_coords(self, scale_factor, *coords): - return tuple(int(coord * scale_factor) for coord in coords) + def _scale_coords(self, x: int, y: int, scale_factors: Tuple[float, float]) -> Tuple[int, int]: + return int(x / scale_factors[0]), int(y / scale_factors[1]) def _get_grid_coords_and_indices( diff --git a/src/eva/vision/utils/mask.py b/src/eva/vision/utils/mask.py deleted file mode 100644 index 4f1cd023..00000000 --- a/src/eva/vision/utils/mask.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Functions for extracting foreground masks.""" - -from typing import Tuple - -import cv2 -import numpy as np - -from eva.vision.data.wsi.backends.base import Wsi - - -def get_mask( - wsi: Wsi, - level_idx: int, - kernel_size: Tuple[int, int] = (7, 7), - gray_threshold: int = 220, - fill_holes: bool = False, -) -> Tuple[np.ndarray, float]: - """Extracts a binary mask from an image. - - Args: - wsi: The WSI object. - level_idx: The level index to extract the mask from. - kernel_size: The size of the kernel for morphological operations. - gray_threshold: The threshold for the gray scale image. - fill_holes: Whether to fill holes in the mask. - """ - image = wsi.read_region((0, 0), len(wsi.level_dimensions) - 1, wsi.level_dimensions[-1]) - - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_size) - gray = np.array(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY), dtype=np.uint8) - mask = np.where(gray < gray_threshold, 1, 0).astype(np.uint8) - - if fill_holes: - mask = cv2.dilate(mask, kernel, iterations=1) - contour, _ = cv2.findContours(mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE) - for cnt in contour: - cv2.drawContours(mask, [cnt], 0, (1,), -1) - - mask_scale_factor = wsi.level_dimensions[-1][0] / wsi.level_dimensions[level_idx][0] - - return mask, mask_scale_factor