diff --git a/pyproject.toml b/pyproject.toml index 6057734..c3c6087 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,8 @@ test = [ ] quality = [ "ruff==0.4.1", - "mypy==1.8.0", + "mypy==1.10.0", + "types-Pillow", "pre-commit>=3.0.0,<4.0.0", ] docs = [ @@ -77,7 +78,8 @@ dev = [ "pytest-pretty>=1.0.0,<2.0.0", # style "ruff==0.4.1", - "mypy==1.8.0", + "mypy==1.10.0", + "types-Pillow", "pre-commit>=3.0.0,<4.0.0", # docs "sphinx>=3.0.0,!=3.5.0", diff --git a/torchcam/methods/core.py b/torchcam/methods/core.py index 8d12d69..5c1ad60 100644 --- a/torchcam/methods/core.py +++ b/torchcam/methods/core.py @@ -7,7 +7,7 @@ from abc import abstractmethod from functools import partial from types import TracebackType -from typing import Any, List, Optional, Tuple, Type, Union, cast +from typing import Any, List, Optional, Tuple, Type, Union import torch import torch.nn.functional as F @@ -267,4 +267,4 @@ def _fuse_cams(cams: List[Tensor], target_shape: Tuple[int, int]) -> Tensor: ] # Fuse them - return cast(Tensor, torch.stack(scaled_cams).max(dim=0).values.squeeze(1)) + return torch.stack(scaled_cams).max(dim=0).values.squeeze(1) diff --git a/torchcam/methods/gradient.py b/torchcam/methods/gradient.py index 90d0f77..260bfe8 100644 --- a/torchcam/methods/gradient.py +++ b/torchcam/methods/gradient.py @@ -4,7 +4,7 @@ # See LICENSE or go to for full license details. from functools import partial -from typing import Any, List, Optional, Tuple, Union, cast +from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor, nn @@ -404,4 +404,4 @@ def _get_weights(self, class_idx: Union[int, List[int]], scores: Tensor, **kwarg @staticmethod def _scale_cams(cams: List[Tensor], gamma: float = 2.0) -> List[Tensor]: # cf. Equation 9 in the paper - return [torch.tanh(cast(Tensor, gamma * cam)) for cam in cams] + return [torch.tanh(gamma * cam) for cam in cams] diff --git a/torchcam/utils.py b/torchcam/utils.py index 17f79ae..24b7419 100644 --- a/torchcam/utils.py +++ b/torchcam/utils.py @@ -3,12 +3,14 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. +from typing import cast + import numpy as np from matplotlib import colormaps as cm -from PIL import Image +from PIL.Image import Image, Resampling, fromarray -def overlay_mask(img: Image.Image, mask: Image.Image, colormap: str = "jet", alpha: float = 0.7) -> Image.Image: +def overlay_mask(img: Image, mask: Image, colormap: str = "jet", alpha: float = 0.7) -> Image: """Overlay a colormapped mask on a background image >>> from PIL import Image @@ -31,7 +33,7 @@ def overlay_mask(img: Image.Image, mask: Image.Image, colormap: str = "jet", alp TypeError: when the arguments have invalid types ValueError: when the alpha argument has an incorrect value """ - if not isinstance(img, Image.Image) or not isinstance(mask, Image.Image): + if not isinstance(img, Image) or not isinstance(mask, Image): raise TypeError("img and mask arguments need to be PIL.Image") if not isinstance(alpha, float) or alpha < 0 or alpha >= 1: @@ -39,9 +41,7 @@ def overlay_mask(img: Image.Image, mask: Image.Image, colormap: str = "jet", alp cmap = cm.get_cmap(colormap) # Resize mask and apply colormap - overlay = mask.resize(img.size, resample=Image.BICUBIC) + overlay = mask.resize(img.size, resample=Resampling.BICUBIC) overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, :3]).astype(np.uint8) # Overlay the image with the mask - overlayed_img = Image.fromarray((alpha * np.asarray(img) + (1 - alpha) * overlay).astype(np.uint8)) - - return overlayed_img + return fromarray((alpha * np.asarray(img) + (1 - alpha) * cast(np.ndarray, overlay)).astype(np.uint8))