Add lower bound for wsi resolution level during mask generation #412

Merged
8 changes: 4 additions & 4 deletions pdm.lock

Some generated files are not rendered by default.

18 changes: 9 additions & 9 deletions src/eva/core/models/modules/head.py
@@ -11,7 +11,7 @@
 
 from eva.core.metrics import structs as metrics_lib
 from eva.core.models.modules import module
-from eva.core.models.modules.typings import INPUT_BATCH, MODEL_TYPE
+from eva.core.models.modules.typings import DATA_SAMPLE, MODEL_TYPE
 from eva.core.models.modules.utils import batch_postprocess, grad


@@ -72,23 +72,23 @@ def forward(self, tensor: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
         return self.head(features).squeeze(-1)
 
     @override
-    def training_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+    def training_step(self, batch: DATA_SAMPLE, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         return self._batch_step(batch)
 
     @override
-    def validation_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+    def validation_step(self, batch: DATA_SAMPLE, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         return self._batch_step(batch)
 
     @override
-    def test_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+    def test_step(self, batch: DATA_SAMPLE, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         return self._batch_step(batch)
 
     @override
-    def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
-        tensor = INPUT_BATCH(*batch).data
+    def predict_step(self, batch: DATA_SAMPLE, *args: Any, **kwargs: Any) -> torch.Tensor:
+        tensor = DATA_SAMPLE(*batch).data
         return tensor if self.backbone is None else self.backbone(tensor)
 
-    def _batch_step(self, batch: INPUT_BATCH) -> STEP_OUTPUT:
+    def _batch_step(self, batch: DATA_SAMPLE) -> STEP_OUTPUT:
         """Performs a model forward step and calculates the loss.
 
         Args:
@@ -97,12 +97,12 @@ def _batch_step(self, batch: INPUT_BATCH) -> STEP_OUTPUT:
         Returns:
             The batch step output.
         """
-        data, targets, metadata = INPUT_BATCH(*batch)
+        data, targets, metadata = DATA_SAMPLE(*batch)
         predictions = self(data)
         loss = self.criterion(predictions, targets)
         return {
             "loss": loss,
             "targets": targets,
             "predictions": predictions,
             "metadata": metadata,
-        }
+        }
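Note: the step methods above work by re-wrapping the raw batch tuple in a NamedTuple to get named field access. A minimal sketch of what a DATA_SAMPLE-style type could look like — the actual definition lives in eva.core.models.modules.typings and may differ:

from typing import Any, NamedTuple

import torch


# Hypothetical stand-in for eva's DATA_SAMPLE; field names and types are assumed.
class DATA_SAMPLE(NamedTuple):
    data: torch.Tensor
    targets: torch.Tensor | None = None
    metadata: dict[str, Any] | None = None


batch = (torch.rand(4, 8), torch.zeros(4), None)
# Re-wrapping the tuple gives named access, as in predict_step and _batch_step:
data, targets, metadata = DATA_SAMPLE(*batch)
tensor = DATA_SAMPLE(*batch).data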
10 changes: 7 additions & 3 deletions src/eva/vision/data/wsi/patching/coordinates.py
@@ -6,7 +6,7 @@
 
 from eva.vision.data.wsi import backends
 from eva.vision.data.wsi.patching import samplers
-from eva.vision.utils.mask import get_mask
+from eva.vision.data.wsi.patching.mask import Mask, get_mask
 
 LRU_CACHE_SIZE = 32

@@ -20,12 +20,14 @@ class PatchCoordinates:
         width: The width of the patches, in pixels (refers to x-dim).
         height: The height of the patches, in pixels (refers to y-dim).
         level_idx: The level index of the patches.
+        mask: The foreground mask of the WSI.
     """
 
     x_y: List[Tuple[int, int]]
     width: int
     height: int
     level_idx: int
+    mask: Mask | None = None
 
     @classmethod
     def from_file(
@@ -54,20 +56,22 @@ def from_file(
         level_mpp = wsi.mpp * wsi.level_downsamples[level_idx]
         mpp_ratio = target_mpp / level_mpp
         scaled_width, scaled_height = int(mpp_ratio * width), int(mpp_ratio * height)
+        mask = None
 
         sample_args = {
             "width": scaled_width,
             "height": scaled_height,
             "layer_shape": wsi.level_dimensions[level_idx],
         }
         if isinstance(sampler, samplers.ForegroundSampler):
-            sample_args["mask"] = get_mask(wsi, level_idx)
+            mask = get_mask(wsi, level_idx)
+            sample_args["mask"] = mask
 
         x_y = []
         for x, y in sampler.sample(**sample_args):
             x_y.append((x, y))
 
-        return cls(x_y, scaled_width, scaled_height, level_idx)
+        return cls(x_y, scaled_width, scaled_height, level_idx, mask)


@functools.lru_cache(LRU_CACHE_SIZE)
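To make the rescaling in from_file concrete, here is a small worked example with invented numbers (none of these values come from the PR): a patch requested at target_mpp is resized so that it covers the same physical area at the chosen level's native resolution.

# All values hypothetical, for illustration only.
wsi_mpp = 0.25            # microns per pixel at level 0
level_downsample = 4.0    # downsample factor of the chosen level
target_mpp = 2.0          # requested resolution
width, height = 224, 224  # requested patch size at target_mpp

# Mirrors the arithmetic in PatchCoordinates.from_file:
level_mpp = wsi_mpp * level_downsample  # 1.0 mpp at this level
mpp_ratio = target_mpp / level_mpp      # 2.0
scaled_width, scaled_height = int(mpp_ratio * width), int(mpp_ratio * height)
print(scaled_width, scaled_height)  # -> 448 448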
78 changes: 78 additions & 0 deletions src/eva/vision/data/wsi/patching/mask.py
@@ -0,0 +1,78 @@
"""Functions for extracting foreground masks."""

import dataclasses
from typing import Tuple

import cv2
import numpy as np

from eva.vision.data.wsi.backends.base import Wsi


@dataclasses.dataclass
class Mask:
"""A class to store the mask of a whole-slide image."""

mask_array: np.ndarray
"""Binary mask array where 1s represent the foreground and 0s represent the background."""

scale_factor: float
"""Factor to scale mask to the wsi coordinates."""


def get_mask(
wsi: Wsi,
level_idx: int,
kernel_size: Tuple[int, int] = (7, 7),
gray_threshold: int = 220,
fill_holes: bool = False,
) -> Mask:
"""Extracts a binary mask from an image.

Args:
wsi: The WSI object.
level_idx: The level index of the WSI at which we specify the coordinates.
kernel_size: The size of the kernel for morphological operations.
gray_threshold: The threshold for the gray scale image.
fill_holes: Whether to fill holes in the mask.
"""
low_res_level = get_lowest_resolution_level(wsi, min_pixels=1000 * 1000)
image = wsi.read_region((0, 0), low_res_level, wsi.level_dimensions[low_res_level])

kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_size)
gray = np.array(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY), dtype=np.uint8)
mask = np.where(gray < gray_threshold, 1, 0).astype(np.uint8)

if fill_holes:
mask = cv2.dilate(mask, kernel, iterations=1)
contour, _ = cv2.findContours(mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
for cnt in contour:
cv2.drawContours(mask, [cnt], 0, (1,), -1)

mask_scale_factor = wsi.level_dimensions[low_res_level][0] / wsi.level_dimensions[level_idx][0]

return Mask(mask_array=mask, scale_factor=mask_scale_factor)


def get_lowest_resolution_level(wsi: Wsi, min_pixels: int | None):
"""Calculates the WSI level corresponding to the lowest resolution/magnification.

Args:
wsi: The WSI object.
min_pixels: If specified, this funciton will return the lowest resolution
level with an area of at least `min_pixels` pixels.

Returns:
The lowest resolution level index of the given WSI.
"""
valid_level_index = len(wsi.level_dimensions) - 1

if min_pixels is None:
return valid_level_index
else:
for index, (width, height) in reversed(list(enumerate(wsi.level_dimensions))):
if width * height >= min_pixels:
valid_level_index = index
break
nkaenzig marked this conversation as resolved.
Show resolved Hide resolved

return valid_level_index
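A quick sketch of how the new lower bound behaves, using an invented level pyramid (the dimensions below are not from the PR): with min_pixels=1000 * 1000, the search starts at the lowest-resolution level and moves toward higher resolutions until a level's pixel area reaches the bound.

# Hypothetical WSI pyramid: (width, height) per level, level 0 is full resolution.
level_dimensions = [(40000, 30000), (10000, 7500), (2500, 1875), (625, 468)]

def lowest_level_with_min_pixels(dims, min_pixels):
    # Same logic as get_lowest_resolution_level, on plain tuples.
    valid = len(dims) - 1
    for index, (width, height) in reversed(list(enumerate(dims))):
        if width * height >= min_pixels:
            valid = index
            break
    return valid

# Level 3 has 625 * 468 = 292,500 pixels (< 1e6), so level 2 is chosen.
print(lowest_level_with_min_pixels(level_dimensions, 1000 * 1000))  # -> 2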
15 changes: 8 additions & 7 deletions src/eva/vision/data/wsi/patching/samplers.py
@@ -6,6 +6,8 @@
 
 import numpy as np
 
+from eva.vision.data.wsi.patching.mask import Mask
+

class Sampler(abc.ABC):
    """Base class for samplers."""
@@ -16,7 +18,7 @@ def sample(
         width: int,
         height: int,
         layer_shape: Tuple[int, int],
-        mask: Tuple[np.ndarray, float] | None = None,
+        mask: Mask | None = None,
     ) -> Generator[Tuple[int, int], None, None]:
         """Sample patch coordinates.

@@ -39,7 +41,7 @@ class ForegroundSampler(Sampler):
     @abc.abstractmethod
     def is_foreground(
         self,
-        mask: Tuple[np.ndarray, float],
+        mask: Mask,
         x: int,
         y: int,
         width: int,
@@ -150,7 +152,7 @@ def sample(
         width: int,
         height: int,
         layer_shape: Tuple[int, int],
-        mask: Tuple[np.ndarray, float],
+        mask: Mask,
     ):
         """Sample patches from a grid containing foreground.

@@ -174,7 +176,7 @@
 
     def is_foreground(
         self,
-        mask: Tuple[np.ndarray, float],
+        mask: Mask,
         x: int,
         y: int,
         width: int,
@@ -191,9 +193,8 @@ def is_foreground(
             height: The height of the patch.
             min_foreground_ratio: The minimum amount of foreground in the patch.
         """
-        mask_array, mask_scale_factor = mask
-        x_, y_, width_, height_ = self._scale_coords(mask_scale_factor, x, y, width, height)
-        patch_mask = mask_array[y_ : y_ + height_, x_ : x_ + width_]
+        x_, y_, width_, height_ = self._scale_coords(mask.scale_factor, x, y, width, height)
+        patch_mask = mask.mask_array[y_ : y_ + height_, x_ : x_ + width_]
         # TODO: look into warning "RuntimeWarning: invalid value encountered in divide"
         return patch_mask.sum() / patch_mask.size > min_foreground_ratio

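To illustrate the foreground test against the new Mask fields, a toy example on a synthetic mask (all values invented): coordinates given at the sampling level are scaled by mask.scale_factor into mask space, and the patch passes if its foreground fraction exceeds the minimum ratio.

import numpy as np

# Synthetic 100x100 mask with a 40x40 foreground square; illustrative only.
mask_array = np.zeros((100, 100), dtype=np.uint8)
mask_array[10:50, 10:50] = 1
scale_factor = 0.1  # the mask is 10x smaller than the sampling level

# Mirrors is_foreground: scale the patch into mask space, then threshold.
x, y, width, height = 100, 100, 400, 400  # patch at the sampling level
x_, y_, w_, h_ = (int(v * scale_factor) for v in (x, y, width, height))
patch = mask_array[y_ : y_ + h_, x_ : x_ + w_]
print(patch.sum() / patch.size)  # -> 1.0, well above any min_foreground_ratio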
41 changes: 0 additions & 41 deletions src/eva/vision/utils/mask.py

This file was deleted.
