From 25a4d70a1bceeb721ca6cd65911b7a83d6c2abb0 Mon Sep 17 00:00:00 2001
From: ProGamerGov
Date: Fri, 8 Apr 2022 14:29:42 -0600
Subject: [PATCH 1/4] Correct `RandomScale`

---
 captum/optim/_param/image/transforms.py | 256 +++++++++++++---
 captum/optim/_utils/typing.py           |  14 +-
 tests/optim/param/test_transforms.py    | 385 ++++++++++++++++++++++--
 3 files changed, 596 insertions(+), 59 deletions(-)

diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py
index df5ff15f7e..c770f3456b 100644
--- a/captum/optim/_param/image/transforms.py
+++ b/captum/optim/_param/image/transforms.py
@@ -8,7 +8,7 @@
 import torch.nn.functional as F
 
 from captum.optim._utils.image.common import nchannels_to_rgb
-from captum.optim._utils.typing import IntSeqOrIntType, NumSeqOrTensorType
+from captum.optim._utils.typing import IntSeqOrIntType, NumSeqOrTensorOrProbDistType
 
 
 class BlendAlpha(nn.Module):
@@ -273,73 +273,255 @@ def center_crop(
     return x
 
 
-def _rand_select(
-    transform_values: NumSeqOrTensorType,
-) -> Union[int, float, torch.Tensor]:
+class RandomScale(nn.Module):
+    """
+    Apply random rescaling on a NCHW tensor using the F.interpolate function.
     """
-    Randomly return a single value from the provided tuple, list, or tensor.
 
-    Args:
+    __constants__ = [
+        "scale",
+        "mode",
+        "align_corners",
+        "_has_align_corners",
+        "recompute_scale_factor",
+        "_has_recompute_scale_factor",
+        "_is_distribution",
+    ]
 
-        transform_values (sequence): A sequence of values to randomly select from.
+    def __init__(
+        self,
+        scale: NumSeqOrTensorOrProbDistType,
+        mode: str = "bilinear",
+        align_corners: Optional[bool] = False,
+        recompute_scale_factor: bool = False,
+    ) -> None:
+        """
+        Args:
+            scale (float, sequence, or torch.distribution): Sequence of rescaling
+                values to randomly select from, or a torch.distributions instance.
+            mode (str, optional): Interpolation mode to use. See documentation of
+                F.interpolate for more details. One of: "bilinear", "nearest", "area",
+                or "bicubic".
+                Default: "bilinear"
+            align_corners (bool, optional): Whether or not to align corners. See
+                documentation of F.interpolate for more details.
+                Default: False
+            recompute_scale_factor (bool, optional): Whether or not to recompute the
+                scale factor. See documentation of F.interpolate for more details.
+                Default: False
+        """
+        super().__init__()
+        assert mode not in ["linear", "trilinear"]
+        if isinstance(scale, torch.distributions.distribution.Distribution):
+            # Distributions are not supported by TorchScript / JIT yet
+            assert scale.batch_shape == torch.Size([])
+            self.scale_distribution = scale
+            self._is_distribution = True
+            self.scale = []
+        else:
+            assert hasattr(scale, "__iter__")
+            if torch.is_tensor(scale):
+                assert cast(torch.Tensor, scale).dim() == 1
+                scale = scale.tolist()
+            assert len(scale) > 0
+            self.scale = [float(s) for s in scale]
+            self._is_distribution = False
+        self.mode = mode
+        self.align_corners = align_corners if mode not in ["nearest", "area"] else None
+        self.recompute_scale_factor = recompute_scale_factor
+        self._has_align_corners = torch.__version__ >= "1.3.0"
+        self._has_recompute_scale_factor = torch.__version__ >= "1.6.0"
 
-    Returns:
-        **value**: A single value from the specified sequence.
-    """
-    n = torch.randint(low=0, high=len(transform_values), size=[1]).item()
-    return transform_values[n]
+    def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor:
+        """
+        Scale an NCHW image tensor based on a specified scale value.
+        Args:
+            x (torch.Tensor): The NCHW image tensor to scale.
+            scale (float): The amount to scale the NCHW image by.
+        Returns:
+            **x** (torch.Tensor): A scaled NCHW image tensor.
+        """
+        if self._has_align_corners:
+            if self._has_recompute_scale_factor:
+                x = F.interpolate(
+                    x,
+                    scale_factor=scale,
+                    mode=self.mode,
+                    align_corners=self.align_corners,
+                    recompute_scale_factor=self.recompute_scale_factor,
+                )
+            else:
+                x = F.interpolate(
+                    x,
+                    scale_factor=scale,
+                    mode=self.mode,
+                    align_corners=self.align_corners,
+                )
+        else:
+            x = F.interpolate(x, scale_factor=scale, mode=self.mode)
+        return x
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Randomly scale an NCHW image tensor.
+        Args:
+            x (torch.Tensor): NCHW image tensor to randomly scale.
+        Returns:
+            **x** (torch.Tensor): A randomly scaled NCHW image *tensor*.
+        """
+        assert x.dim() == 4
+        if self._is_distribution:
+            scale = float(self.scale_distribution.sample().item())
+        else:
+            n = int(
+                torch.randint(
+                    low=0,
+                    high=len(self.scale),
+                    size=[1],
+                    dtype=torch.int64,
+                    layout=torch.strided,
+                    device=x.device,
+                ).item()
+            )
+            scale = self.scale[n]
+        return self._scale_tensor(x, scale=scale)
 
 
-class RandomScale(nn.Module):
+class RandomScaleAffine(nn.Module):
     """
     Apply random rescaling on a NCHW tensor.
+    This random scaling transform utilizes F.affine_grid & F.grid_sample, and as a
+    result has two key differences from the default RandomScale transform. This
+    transform either shrinks an image while adding a background, or center crops the
+    image and then resizes it to a larger size. This means that the output image
+    shape is the same shape as the input image.
+    In contrast to RandomScaleAffine, the default RandomScale transform simply resizes
+    the input image using F.interpolate.
     """
 
-    def __init__(self, scale: NumSeqOrTensorType) -> None:
+    __constants__ = [
+        "scale",
+        "mode",
+        "padding_mode",
+        "align_corners",
+        "_has_align_corners",
+        "_is_distribution",
+    ]
+
+    def __init__(
+        self,
+        scale: NumSeqOrTensorOrProbDistType,
+        mode: str = "bilinear",
+        padding_mode: str = "zeros",
+        align_corners: bool = False,
+    ) -> None:
         """
         Args:
-
-        scale (float, sequence): Tuple of rescaling values to randomly select from.
+            scale (float, sequence, or torch.distribution): Sequence of rescaling
+                values to randomly select from, or a torch.distributions instance.
+            mode (str, optional): Interpolation mode to use. See documentation of
+                F.grid_sample for more details. One of: "bilinear", "nearest", or
+                "bicubic".
+                Default: "bilinear"
+            padding_mode (str, optional): Padding mode for values that fall outside of
+                the grid. See documentation of F.grid_sample for more details. One of:
+                "zeros", "border", or "reflection".
+                Default: "zeros"
+            align_corners (bool, optional): Whether or not to align corners. See
+                documentation of F.affine_grid & F.grid_sample for more details.
+ Default: False """ super().__init__() - self.scale = scale + if isinstance(scale, torch.distributions.distribution.Distribution): + # Distributions are not supported by TorchScript / JIT yet + assert scale.batch_shape == torch.Size([]) + self.scale_distribution = scale + self._is_distribution = True + self.scale = [] + else: + assert hasattr(scale, "__iter__") + if torch.is_tensor(scale): + assert cast(torch.Tensor, scale).dim() == 1 + scale = scale.tolist() + assert len(scale) > 0 + self.scale = [float(s) for s in scale] + self._is_distribution = False + self.mode = mode + self.padding_mode = padding_mode + self.align_corners = align_corners + self._has_align_corners = torch.__version__ >= "1.3.0" - def get_scale_mat( - self, m: IntSeqOrIntType, device: torch.device, dtype: torch.dtype + def _get_scale_mat( + self, + m: float, + device: torch.device, + dtype: torch.dtype, ) -> torch.Tensor: + """ + Create a scale matrix tensor. + Args: + m (float): The scale value to use. + Returns: + **scale_mat** (torch.Tensor): A scale matrix. + """ scale_mat = torch.tensor( [[m, 0.0, 0.0], [0.0, m, 0.0]], device=device, dtype=dtype ) return scale_mat - def scale_tensor( - self, x: torch.Tensor, scale: Union[int, float, torch.Tensor] - ) -> torch.Tensor: - scale_matrix = self.get_scale_mat(scale, x.device, x.dtype)[None, ...].repeat( + def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor: + """ + Scale an NCHW image tensor based on a specified scale value. + Args: + x (torch.Tensor): The NCHW image tensor to scale. + scale (float): The amount to scale the NCHW image by. + Returns: + **x** (torch.Tensor): A scaled NCHW image tensor. + """ + scale_matrix = self._get_scale_mat(scale, x.device, x.dtype)[None, ...].repeat( x.shape[0], 1, 1 ) - if torch.__version__ >= "1.3.0": + if self._has_align_corners: # Pass align_corners explicitly for torch >= 1.3.0 - grid = F.affine_grid(scale_matrix, x.size(), align_corners=False) - x = F.grid_sample(x, grid, align_corners=False) + grid = F.affine_grid( + scale_matrix, x.size(), align_corners=self.align_corners + ) + x = F.grid_sample( + x, + grid, + mode=self.mode, + padding_mode=self.padding_mode, + align_corners=self.align_corners, + ) else: grid = F.affine_grid(scale_matrix, x.size()) - x = F.grid_sample(x, grid) + x = F.grid_sample(x, grid, mode=self.mode, padding_mode=self.padding_mode) return x - def forward(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Randomly scale / zoom in or out of a tensor. - + Randomly scale an NCHW image tensor. Args: - - input (torch.Tensor): Input to randomly scale. - + x (torch.Tensor): NCHW image tensor to randomly scale. Returns: - **tensor** (torch.Tensor): Scaled *tensor*. + **x** (torch.Tensor): A randomly scaled NCHW image *tensor*. 
""" - scale = _rand_select(self.scale) - return self.scale_tensor(input, scale=scale) + assert x.dim() == 4 + if self._is_distribution: + scale = float(self.scale_distribution.sample().item()) + else: + n = int( + torch.randint( + low=0, + high=len(self.scale), + size=[1], + dtype=torch.int64, + layout=torch.strided, + device=x.device, + ).item() + ) + scale = self.scale[n] + return self._scale_tensor(x, scale=scale) class RandomSpatialJitter(torch.nn.Module): @@ -401,7 +583,7 @@ class RandomRotation(nn.Module): def __init__( self, - degrees: NumSeqOrTensorType, + degrees: NumSeqOrTensorOrProbDistType, mode: str = "bilinear", padding_mode: str = "zeros", align_corners: bool = False, diff --git a/captum/optim/_utils/typing.py b/captum/optim/_utils/typing.py index cf699beff1..a0e3d6f1c0 100755 --- a/captum/optim/_utils/typing.py +++ b/captum/optim/_utils/typing.py @@ -1,7 +1,7 @@ import sys from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union -from torch import Tensor +from torch import Tensor, __version__ from torch.nn import Module from torch.optim import Optimizer @@ -33,6 +33,16 @@ def cleanup(self) -> None: LossFunction = Callable[[ModuleOutputMapping], Tensor] SingleTargetLossFunction = Callable[[Tensor], Tensor] -NumSeqOrTensorType = Union[Sequence[int], Sequence[float], Tensor] +if __version__ < "1.4.0": + NumSeqOrTensorOrProbDistType = Union[Sequence[int], Sequence[float], Tensor] +else: + from torch import distributions + + NumSeqOrTensorOrProbDistType = Union[ + Sequence[int], + Sequence[float], + Tensor, + distributions.distribution.Distribution, + ] IntSeqOrIntType = Union[List[int], Tuple[int], Tuple[int, int], int] TupleOfTensorsOrTensorType = Union[Tuple[Tensor, ...], Tensor] diff --git a/tests/optim/param/test_transforms.py b/tests/optim/param/test_transforms.py index fb29c3b51d..3f048f51e4 100644 --- a/tests/optim/param/test_transforms.py +++ b/tests/optim/param/test_transforms.py @@ -15,31 +15,306 @@ from tests.optim.helpers import numpy_transforms -class TestRandSelect(BaseTest): - def test_rand_select(self) -> None: - a = (1, 2, 3, 4, 5) - b = torch.Tensor([0.1, -5, 56.7, 99.0]) +class TestRandomScale(BaseTest): + def test_random_scale_init(self) -> None: + scale_module = transforms.RandomScale(scale=[1, 0.975, 1.025, 0.95, 1.05]) + self.assertEqual(scale_module.scale, [1.0, 0.975, 1.025, 0.95, 1.05]) + self.assertFalse(scale_module._is_distribution) + self.assertEqual(scale_module.mode, "bilinear") + self.assertFalse(scale_module.align_corners) + self.assertFalse(scale_module.recompute_scale_factor) + + def test_random_scale_tensor_scale(self) -> None: + scale = torch.tensor([1, 0.975, 1.025, 0.95, 1.05]) + scale_module = transforms.RandomScale(scale=scale) + self.assertEqual(scale_module.scale, scale.tolist()) + + def test_random_scale_int_scale(self) -> None: + scale = [1, 2, 3, 4, 5] + scale_module = transforms.RandomScale(scale=scale) + for s in scale_module.scale: + self.assertIsInstance(s, float) + self.assertEqual(scale_module.scale, [1.0, 2.0, 3.0, 4.0, 5.0]) + + def test_random_scale_scale_distributions(self) -> None: + scale = torch.distributions.Uniform(0.95, 1.05) + scale_module = transforms.RandomScale(scale=scale) + self.assertIsInstance( + scale_module.scale_distribution, + torch.distributions.distribution.Distribution, + ) + self.assertTrue(scale_module._is_distribution) - self.assertIn(transforms._rand_select(a), a) - self.assertIn(transforms._rand_select(b), b) + def test_random_scale_torch_version_check(self) -> None: + 
scale_module = transforms.RandomScale([1.0])
+        has_align_corners = torch.__version__ >= "1.3.0"
+        self.assertEqual(scale_module._has_align_corners, has_align_corners)
+
+        has_recompute_scale_factor = torch.__version__ >= "1.6.0"
+        self.assertEqual(
+            scale_module._has_recompute_scale_factor, has_recompute_scale_factor
+        )
+
+    def test_random_scale_downscaling(self) -> None:
+        scale_module = transforms.RandomScale(scale=[0.5])
+        test_tensor = torch.arange(0, 1 * 1 * 10 * 10).view(1, 1, 10, 10).float()
+
+        scaled_tensor = scale_module._scale_tensor(test_tensor, 0.5)
+
+        expected_tensor = torch.tensor(
+            [
+                [
+                    [
+                        [5.5000, 7.5000, 9.5000, 11.5000, 13.5000],
+                        [25.5000, 27.5000, 29.5000, 31.5000, 33.5000],
+                        [45.5000, 47.5000, 49.5000, 51.5000, 53.5000],
+                        [65.5000, 67.5000, 69.5000, 71.5000, 73.5000],
+                        [85.5000, 87.5000, 89.5000, 91.5000, 93.5000],
+                    ]
+                ]
+            ]
+        )
+
+        assertTensorAlmostEqual(
+            self,
+            scaled_tensor,
+            expected_tensor,
+            0,
+        )
+
+    def test_random_scale_upscaling(self) -> None:
+        scale_module = transforms.RandomScale(scale=[1.5])
+        test_tensor = torch.arange(0, 1 * 1 * 2 * 2).view(1, 1, 2, 2).float()
+
+        scaled_tensor = scale_module._scale_tensor(test_tensor, 1.5)
+
+        expected_tensor = torch.tensor(
+            [
+                [
+                    [
+                        [0.0000, 0.5000, 1.0000],
+                        [1.0000, 1.5000, 2.0000],
+                        [2.0000, 2.5000, 3.0000],
+                    ]
+                ]
+            ]
+        )
+
+        assertTensorAlmostEqual(
+            self,
+            scaled_tensor,
+            expected_tensor,
+            0,
+        )
+
+    def test_random_scale_forward_exact(self) -> None:
+        scale_module = transforms.RandomScale(scale=[0.5])
+        test_tensor = torch.arange(0, 1 * 1 * 10 * 10).view(1, 1, 10, 10).float()
+
+        scaled_tensor = scale_module(test_tensor)
+
+        expected_tensor = torch.tensor(
+            [
+                [
+                    [
+                        [5.5000, 7.5000, 9.5000, 11.5000, 13.5000],
+                        [25.5000, 27.5000, 29.5000, 31.5000, 33.5000],
+                        [45.5000, 47.5000, 49.5000, 51.5000, 53.5000],
+                        [65.5000, 67.5000, 69.5000, 71.5000, 73.5000],
+                        [85.5000, 87.5000, 89.5000, 91.5000, 93.5000],
+                    ]
+                ]
+            ]
+        )
+
+        assertTensorAlmostEqual(
+            self,
+            scaled_tensor,
+            expected_tensor,
+            0,
+        )
+
+    def test_random_scale_forward_exact_nearest(self) -> None:
+        scale_module = transforms.RandomScale(scale=[0.5], mode="nearest")
+        self.assertIsNone(scale_module.align_corners)
+        self.assertEqual(scale_module.mode, "nearest")
+
+        test_tensor = torch.arange(0, 1 * 1 * 10 * 10).view(1, 1, 10, 10).float()
+
+        scaled_tensor = scale_module(test_tensor)
+
+        expected_tensor = torch.tensor(
+            [
+                [
+                    [
+                        [0.0, 2.0, 4.0, 6.0, 8.0],
+                        [20.0, 22.0, 24.0, 26.0, 28.0],
+                        [40.0, 42.0, 44.0, 46.0, 48.0],
+                        [60.0, 62.0, 64.0, 66.0, 68.0],
+                        [80.0, 82.0, 84.0, 86.0, 88.0],
+                    ]
+                ]
+            ]
+        )
+
+        assertTensorAlmostEqual(
+            self,
+            scaled_tensor,
+            expected_tensor,
+            0,
+        )
+
+    def test_random_scale_forward_exact_align_corners(self) -> None:
+        if torch.__version__ <= "1.2.0":
+            raise unittest.SkipTest(
+                "Skipping RandomScale exact align corners forward due to"
+                + " insufficient Torch version."
+ ) + scale_module = transforms.RandomScale(scale=[0.5], align_corners=True) + self.assertTrue(scale_module.align_corners) + + test_tensor = torch.arange(0, 1 * 1 * 10 * 10).view(1, 1, 10, 10).float() + + scaled_tensor = scale_module(test_tensor) + + expected_tensor = torch.tensor( + [ + [ + [ + [0.0000, 2.2500, 4.5000, 6.7500, 9.0000], + [22.5000, 24.7500, 27.0000, 29.2500, 31.5000], + [45.0000, 47.2500, 49.5000, 51.7500, 54.0000], + [67.5000, 69.7500, 72.0000, 74.2500, 76.5000], + [90.0000, 92.2500, 94.5000, 96.7500, 99.0000], + ] + ] + ] + ) + assertTensorAlmostEqual( + self, + scaled_tensor, + expected_tensor, + 0, + ) + + def test_random_scale_forward(self) -> None: + scale_module = transforms.RandomScale(scale=[0.5]) + test_tensor = torch.ones(1, 3, 10, 10) + output_tensor = scale_module(test_tensor) + self.assertEqual(list(output_tensor.shape), [1, 3, 5, 5]) + + def test_random_scale_forward_distributions(self) -> None: + scale = torch.distributions.Uniform(0.95, 1.05) + scale_module = transforms.RandomScale(scale=scale) + test_tensor = torch.ones(1, 3, 10, 10) + output_tensor = scale_module(test_tensor) + self.assertTrue(torch.is_tensor(output_tensor)) + + def test_random_scale_jit_module(self) -> None: + if torch.__version__ <= "1.8.0": + raise unittest.SkipTest( + "Skipping RandomScale JIT module test due to insufficient" + + " Torch version." + ) + scale_module = transforms.RandomScale(scale=[1.5]) + jit_scale_module = torch.jit.script(scale_module) + + test_tensor = torch.arange(0, 1 * 1 * 2 * 2).view(1, 1, 2, 2).float() + scaled_tensor = jit_scale_module(test_tensor) + + expected_tensor = torch.tensor( + [ + [ + [ + [0.0000, 0.5000, 1.0000], + [1.0000, 1.5000, 2.0000], + [2.0000, 2.5000, 3.0000], + ] + ] + ] + ) + + assertTensorAlmostEqual( + self, + scaled_tensor, + expected_tensor, + 0, + ) + + +class TestRandomScaleAffine(BaseTest): + def test_random_scale_affine_init(self) -> None: + scale_module = transforms.RandomScaleAffine(scale=[1, 0.975, 1.025, 0.95, 1.05]) + self.assertEqual(scale_module.scale, [1.0, 0.975, 1.025, 0.95, 1.05]) + self.assertFalse(scale_module._is_distribution) + self.assertEqual(scale_module.mode, "bilinear") + self.assertEqual(scale_module.padding_mode, "zeros") + self.assertFalse(scale_module.align_corners) + + def test_random_scale_affine_tensor_scale(self) -> None: + scale = torch.tensor([1, 0.975, 1.025, 0.95, 1.05]) + scale_module = transforms.RandomScaleAffine(scale=scale) + self.assertEqual(scale_module.scale, scale.tolist()) + + def test_random_scale_affine_int_scale(self) -> None: + scale = [1, 2, 3, 4, 5] + scale_module = transforms.RandomScaleAffine(scale=scale) + for s in scale_module.scale: + self.assertIsInstance(s, float) + self.assertEqual(scale_module.scale, [1.0, 2.0, 3.0, 4.0, 5.0]) + + def test_random_scale_affine_scale_distributions(self) -> None: + scale = torch.distributions.Uniform(0.95, 1.05) + scale_module = transforms.RandomScaleAffine(scale=scale) + self.assertIsInstance( + scale_module.scale_distribution, + torch.distributions.distribution.Distribution, + ) + self.assertTrue(scale_module._is_distribution) + + def test_random_scale_affine_torch_version_check(self) -> None: + scale_module = transforms.RandomScaleAffine([1.0]) + _has_align_corners = torch.__version__ >= "1.3.0" + self.assertEqual(scale_module._has_align_corners, _has_align_corners) + + def test_random_scale_affine_matrix(self) -> None: + scale_module = transforms.RandomScaleAffine(scale=[0.5]) + test_tensor = torch.ones(1, 3, 3, 3) + # Test scale matrices + 
+ assertTensorAlmostEqual( + self, + scale_module._get_scale_mat(0.5, test_tensor.device, test_tensor.dtype), + torch.tensor([[0.5000, 0.0000, 0.0000], [0.0000, 0.5000, 0.0000]]), + 0, + ) + + assertTensorAlmostEqual( + self, + scale_module._get_scale_mat(1.24, test_tensor.device, test_tensor.dtype), + torch.tensor([[1.2400, 0.0000, 0.0000], [0.0000, 1.2400, 0.0000]]), + 0, + ) + + def test_random_scale_affine_downscaling(self) -> None: + scale_module = transforms.RandomScaleAffine(scale=[0.5]) test_tensor = torch.ones(1, 3, 3, 3) - # Test rescaling assertTensorAlmostEqual( self, - scale_module.scale_tensor(test_tensor, 0.5), + scale_module._scale_tensor(test_tensor, 0.5), torch.ones(3, 1).repeat(3, 1, 3).unsqueeze(0), 0, ) + def test_random_scale_affine_upscaling(self) -> None: + scale_module = transforms.RandomScaleAffine(scale=[1.5]) + test_tensor = torch.ones(1, 3, 3, 3) + assertTensorAlmostEqual( self, - scale_module.scale_tensor(test_tensor, 1.5), + scale_module._scale_tensor(test_tensor, 1.5), torch.tensor( [ [0.2500, 0.5000, 0.2500], @@ -52,22 +327,92 @@ def test_random_scale(self) -> None: 0, ) - def test_random_scale_matrix(self) -> None: - scale_module = transforms.RandomScale(scale=(1, 0.975, 1.025, 0.95, 1.05)) - test_tensor = torch.ones(1, 3, 3, 3) - # Test scale matrices + def test_random_scale_affine_forward_exact(self) -> None: + scale_module = transforms.RandomScaleAffine(scale=[1.5]) + test_tensor = torch.arange(0, 1 * 1 * 4 * 4).view(1, 1, 4, 4).float() + scaled_tensor = scale_module(test_tensor) + + expected_tensor = torch.tensor( + [ + [ + [ + [0.0000, 0.1875, 0.5625, 0.1875], + [0.7500, 3.7500, 5.2500, 1.5000], + [2.2500, 9.7500, 11.2500, 3.0000], + [0.7500, 3.1875, 3.5625, 0.9375], + ] + ] + ] + ) assertTensorAlmostEqual( self, - scale_module.get_scale_mat(0.5, test_tensor.device, test_tensor.dtype), - torch.tensor([[0.5000, 0.0000, 0.0000], [0.0000, 0.5000, 0.0000]]), + scaled_tensor, + expected_tensor, 0, ) + def test_random_scale_affine_forward_exact_mode_nearest(self) -> None: + scale_module = transforms.RandomScaleAffine(scale=[1.5], mode="nearest") + self.assertEqual(scale_module.mode, "nearest") + test_tensor = torch.arange(0, 1 * 1 * 4 * 4).view(1, 1, 4, 4).float() + + scaled_tensor = scale_module(test_tensor) + expected_tensor = torch.tensor( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 5.0, 6.0, 0.0], + [0.0, 9.0, 10.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ] + ] + ] + ) + assertTensorAlmostEqual( self, - scale_module.get_scale_mat(1.24, test_tensor.device, test_tensor.dtype), - torch.tensor([[1.2400, 0.0000, 0.0000], [0.0000, 1.2400, 0.0000]]), + scaled_tensor, + expected_tensor, + 0, + ) + + def test_random_scale_affine_forward(self) -> None: + scale_module = transforms.RandomScaleAffine(scale=[0.5]) + test_tensor = torch.ones(1, 3, 10, 10) + output_tensor = scale_module(test_tensor) + self.assertEqual(list(output_tensor.shape), list(test_tensor.shape)) + + def test_random_scale_affine_forward_distributions(self) -> None: + scale = torch.distributions.Uniform(0.95, 1.05) + scale_module = transforms.RandomScaleAffine(scale=scale) + test_tensor = torch.ones(1, 3, 10, 10) + output_tensor = scale_module(test_tensor) + self.assertEqual(list(output_tensor.shape), list(test_tensor.shape)) + + def test_random_scale_affine_jit_module(self) -> None: + if torch.__version__ <= "1.8.0": + raise unittest.SkipTest( + "Skipping RandomScaleAffine JIT module test due to insufficient" + + " Torch version." 
+            )
+        scale_module = transforms.RandomScaleAffine(scale=[1.5])
+        jit_scale_module = torch.jit.script(scale_module)
+        test_tensor = torch.ones(1, 3, 3, 3)
+
+        assertTensorAlmostEqual(
+            self,
+            jit_scale_module(test_tensor),
+            torch.tensor(
+                [
+                    [0.2500, 0.5000, 0.2500],
+                    [0.5000, 1.0000, 0.5000],
+                    [0.2500, 0.5000, 0.2500],
+                ]
+            )
+            .repeat(3, 1, 1)
+            .unsqueeze(0),
             0,
         )

From 70357ec5e6ad288c1ecd81f2d6c4fce186779ba3 Mon Sep 17 00:00:00 2001
From: ProGamerGov
Date: Sat, 9 Apr 2022 10:01:13 -0600
Subject: [PATCH 2/4] Add `RandomScaleAffine` to transforms.py `__all__`

---
 captum/optim/_param/image/transforms.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py
index c770f3456b..7be46cdb13 100644
--- a/captum/optim/_param/image/transforms.py
+++ b/captum/optim/_param/image/transforms.py
@@ -1000,6 +1000,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     "CenterCrop",
     "center_crop",
     "RandomScale",
+    "RandomScaleAffine",
     "RandomSpatialJitter",
     "RandomRotation",
     "ScaleInputRange",

From d88b4dac9c9481b63c5abdc4c981bb3d22232c3b Mon Sep 17 00:00:00 2001
From: ProGamerGov
Date: Wed, 13 Apr 2022 09:11:01 -0600
Subject: [PATCH 3/4] Add `F.interpolate`'s `antialias` parameter to `RandomScale`

---
 captum/optim/_param/image/transforms.py | 50 +++++++++++++++++++++----
 tests/optim/param/test_transforms.py    | 36 ++++++++++++++++++
 2 files changed, 79 insertions(+), 7 deletions(-)

diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py
index 7be46cdb13..9b6d7c34a5 100644
--- a/captum/optim/_param/image/transforms.py
+++ b/captum/optim/_param/image/transforms.py
@@ -285,6 +285,8 @@ class RandomScale(nn.Module):
         "_has_align_corners",
         "recompute_scale_factor",
         "_has_recompute_scale_factor",
+        "antialias",
+        "_has_antialias",
         "_is_distribution",
     ]
 
@@ -294,9 +296,11 @@ def __init__(
         mode: str = "bilinear",
         align_corners: Optional[bool] = False,
         recompute_scale_factor: bool = False,
+        antialias: bool = False,
     ) -> None:
         """
         Args:
+
             scale (float, sequence, or torch.distribution): Sequence of rescaling
                 values to randomly select from, or a torch.distributions instance.
             mode (str, optional): Interpolation mode to use. See documentation of
@@ -309,6 +313,10 @@ def __init__(
             recompute_scale_factor (bool, optional): Whether or not to recompute the
                 scale factor. See documentation of F.interpolate for more details.
                 Default: False
+            antialias (bool, optional): Whether or not to use anti-aliasing. This
+                feature is currently only available for "bilinear" and "bicubic"
+                modes. See documentation of F.interpolate for more details.
+                Default: False
         """
         super().__init__()
         assert mode not in ["linear", "trilinear"]
@@ -329,27 +337,42 @@ def __init__(
         self.mode = mode
         self.align_corners = align_corners if mode not in ["nearest", "area"] else None
         self.recompute_scale_factor = recompute_scale_factor
+        self.antialias = antialias
         self._has_align_corners = torch.__version__ >= "1.3.0"
         self._has_recompute_scale_factor = torch.__version__ >= "1.6.0"
+        self._has_antialias = torch.__version__ >= "1.11.0"
 
     def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor:
         """
         Scale an NCHW image tensor based on a specified scale value.
+
         Args:
+
             x (torch.Tensor): The NCHW image tensor to scale.
             scale (float): The amount to scale the NCHW image by.
+
         Returns:
             **x** (torch.Tensor): A scaled NCHW image tensor.
""" if self._has_align_corners: if self._has_recompute_scale_factor: - x = F.interpolate( - x, - scale_factor=scale, - mode=self.mode, - align_corners=self.align_corners, - recompute_scale_factor=self.recompute_scale_factor, - ) + if self._has_antialias: + x = F.interpolate( + x, + scale_factor=scale, + mode=self.mode, + align_corners=self.align_corners, + recompute_scale_factor=self.recompute_scale_factor, + antialias=self.antialias, + ) + else: + x = F.interpolate( + x, + scale_factor=scale, + mode=self.mode, + align_corners=self.align_corners, + recompute_scale_factor=self.recompute_scale_factor, + ) else: x = F.interpolate( x, @@ -364,8 +387,11 @@ def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: """ Randomly scale an NCHW image tensor. + Args: + x (torch.Tensor): NCHW image tensor to randomly scale. + Returns: **x** (torch.Tensor): A randomly scaled NCHW image *tensor*. """ @@ -417,6 +443,7 @@ def __init__( ) -> None: """ Args: + scale (float, sequence, or torch.distribution): Sequence of rescaling values to randomly select from, or a torch.distributions instance. mode (str, optional): Interpolation mode to use. See documentation of @@ -459,8 +486,11 @@ def _get_scale_mat( ) -> torch.Tensor: """ Create a scale matrix tensor. + Args: + m (float): The scale value to use. + Returns: **scale_mat** (torch.Tensor): A scale matrix. """ @@ -472,9 +502,12 @@ def _get_scale_mat( def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor: """ Scale an NCHW image tensor based on a specified scale value. + Args: + x (torch.Tensor): The NCHW image tensor to scale. scale (float): The amount to scale the NCHW image by. + Returns: **x** (torch.Tensor): A scaled NCHW image tensor. """ @@ -501,8 +534,11 @@ def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: """ Randomly scale an NCHW image tensor. + Args: + x (torch.Tensor): NCHW image tensor to randomly scale. + Returns: **x** (torch.Tensor): A randomly scaled NCHW image *tensor*. """ diff --git a/tests/optim/param/test_transforms.py b/tests/optim/param/test_transforms.py index 3f048f51e4..b0c5b7bb48 100644 --- a/tests/optim/param/test_transforms.py +++ b/tests/optim/param/test_transforms.py @@ -23,6 +23,7 @@ def test_random_scale_init(self) -> None: self.assertEqual(scale_module.mode, "bilinear") self.assertFalse(scale_module.align_corners) self.assertFalse(scale_module.recompute_scale_factor) + self.assertFalse(scale_module.antialias) def test_random_scale_tensor_scale(self) -> None: scale = torch.tensor([1, 0.975, 1.025, 0.95, 1.05]) @@ -56,6 +57,9 @@ def test_random_scale_torch_version_check(self) -> None: scale_module._has_recompute_scale_factor, has_recompute_scale_factor ) + has_antialias = torch.__version__ >= "1.11.0" + self.assertEqual(scale_module._has_antialias, has_antialias) + def test_random_scale_downscaling(self) -> None: scale_module = transforms.RandomScale(scale=[0.5]) test_tensor = torch.arange(0, 1 * 1 * 10 * 10).view(1, 1, 10, 10).float() @@ -108,6 +112,38 @@ def test_random_scale_upscaling(self) -> None: 0, ) + def test_random_scale_antialias(self) -> None: + if torch.__version__ < "1.11.0": + raise unittest.SkipTest( + "Skipping RandomScale antialias test" + + " due to insufficient Torch version." 
+ ) + scale_module = transforms.RandomScale(scale=[0.5], antialias=True) + test_tensor = torch.arange(0, 1 * 1 * 10 * 10).view(1, 1, 10, 10).float() + + scaled_tensor = scale_module._scale_tensor(test_tensor, 0.5) + + expected_tensor = torch.tensor( + [ + [ + [ + [7.8571, 9.6429, 11.6429, 13.6429, 15.4286], + [25.7143, 27.5000, 29.5000, 31.5000, 33.2857], + [45.7143, 47.5000, 49.5000, 51.5000, 53.2857], + [65.7143, 67.5000, 69.5000, 71.5000, 73.2857], + [83.5714, 85.3571, 87.3571, 89.3571, 91.1429], + ] + ] + ] + ) + + assertTensorAlmostEqual( + self, + scaled_tensor, + expected_tensor, + 0.0005, + ) + def test_random_forward_exact(self) -> None: scale_module = transforms.RandomScale(scale=[0.5]) test_tensor = torch.arange(0, 1 * 1 * 10 * 10).view(1, 1, 10, 10).float() From 3ad19752a55e12e2b196d6b33e793c072e0f952c Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Mon, 9 May 2022 12:53:23 -0600 Subject: [PATCH 4/4] Fix spacing --- captum/optim/_param/image/transforms.py | 1 - 1 file changed, 1 deletion(-) diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py index b50c290a91..b81b4bfd22 100644 --- a/captum/optim/_param/image/transforms.py +++ b/captum/optim/_param/image/transforms.py @@ -718,7 +718,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self._scale_tensor(x, scale=scale) - class RandomSpatialJitter(torch.nn.Module): """ Apply random spatial translations on a NCHW tensor.