diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py index de642fc47c..b81b4bfd22 100644 --- a/captum/optim/_param/image/transforms.py +++ b/captum/optim/_param/image/transforms.py @@ -443,6 +443,8 @@ class RandomScale(nn.Module): "_has_align_corners", "recompute_scale_factor", "_has_recompute_scale_factor", + "antialias", + "_has_antialias", "_is_distribution", ] @@ -452,6 +454,7 @@ def __init__( mode: str = "bilinear", align_corners: Optional[bool] = False, recompute_scale_factor: bool = False, + antialias: bool = False, ) -> None: """ Args: @@ -468,6 +471,10 @@ def __init__( recompute_scale_factor (bool, optional): Whether or not to recompute the scale factor See documentation of F.interpolate for more details. Default: False + antialias (bool, optional): Whether or not to use anti-aliasing. This + feature is currently only available for "bilinear" and "bicubic" + modes. See documentation of F.interpolate for more details. + Default: False """ super().__init__() assert mode not in ["linear", "trilinear"] @@ -488,8 +495,10 @@ def __init__( self.mode = mode self.align_corners = align_corners if mode not in ["nearest", "area"] else None self.recompute_scale_factor = recompute_scale_factor + self.antialias = antialias self._has_align_corners = torch.__version__ >= "1.3.0" self._has_recompute_scale_factor = torch.__version__ >= "1.6.0" + self._has_antialias = torch.__version__ >= "1.11.0" def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor: """ @@ -505,13 +514,23 @@ def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor: """ if self._has_align_corners: if self._has_recompute_scale_factor: - x = F.interpolate( - x, - scale_factor=scale, - mode=self.mode, - align_corners=self.align_corners, - recompute_scale_factor=self.recompute_scale_factor, - ) + if self._has_antialias: + x = F.interpolate( + x, + scale_factor=scale, + mode=self.mode, + align_corners=self.align_corners, 
recompute_scale_factor=self.recompute_scale_factor, + antialias=self.antialias, + ) + else: + x = F.interpolate( + x, + scale_factor=scale, + mode=self.mode, + align_corners=self.align_corners, + recompute_scale_factor=self.recompute_scale_factor, + ) else: x = F.interpolate( x, @@ -555,13 +574,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class RandomScaleAffine(nn.Module): """ Apply random rescaling on a NCHW tensor. - This random scaling transform utilizes F.affine_grid & F.grid_sample, and as a result has two key differences to the default RandomScale transforms This transform either shrinks an image while adding a background, or center crops image and then resizes it to a larger size. This means that the output image shape is the same shape as the input image. - In constrast to RandomScaleAffine, the default RandomScale transform simply resizes the input image using F.interpolate. """ @@ -1328,6 +1345,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: "CenterCrop", "center_crop", "RandomScale", + "RandomScaleAffine", "RandomSpatialJitter", "RandomRotation", "ScaleInputRange", diff --git a/tests/optim/param/test_transforms.py b/tests/optim/param/test_transforms.py index 516fe4c5da..6e07277c61 100644 --- a/tests/optim/param/test_transforms.py +++ b/tests/optim/param/test_transforms.py @@ -23,6 +23,7 @@ def test_random_scale_init(self) -> None: self.assertEqual(scale_module.mode, "bilinear") self.assertFalse(scale_module.align_corners) self.assertFalse(scale_module.recompute_scale_factor) + self.assertFalse(scale_module.antialias) def test_random_scale_tensor_scale(self) -> None: scale = torch.tensor([1, 0.975, 1.025, 0.95, 1.05]) @@ -56,6 +57,9 @@ def test_random_scale_torch_version_check(self) -> None: scale_module._has_recompute_scale_factor, has_recompute_scale_factor ) + has_antialias = torch.__version__ >= "1.11.0" + self.assertEqual(scale_module._has_antialias, has_antialias) + def test_random_scale_downscaling(self) -> None: scale_module 
= transforms.RandomScale(scale=[0.5]) test_tensor = torch.arange(0, 1 * 1 * 10 * 10).view(1, 1, 10, 10).float() @@ -108,6 +112,38 @@ def test_random_scale_upscaling(self) -> None: 0, ) + def test_random_scale_antialias(self) -> None: + if torch.__version__ < "1.11.0": + raise unittest.SkipTest( + "Skipping RandomScale antialias test" + + " due to insufficient Torch version." + ) + scale_module = transforms.RandomScale(scale=[0.5], antialias=True) + test_tensor = torch.arange(0, 1 * 1 * 10 * 10).view(1, 1, 10, 10).float() + + scaled_tensor = scale_module._scale_tensor(test_tensor, 0.5) + + expected_tensor = torch.tensor( + [ + [ + [ + [7.8571, 9.6429, 11.6429, 13.6429, 15.4286], + [25.7143, 27.5000, 29.5000, 31.5000, 33.2857], + [45.7143, 47.5000, 49.5000, 51.5000, 53.2857], + [65.7143, 67.5000, 69.5000, 71.5000, 73.2857], + [83.5714, 85.3571, 87.3571, 89.3571, 91.1429], + ] + ] + ] + ) + + assertTensorAlmostEqual( + self, + scaled_tensor, + expected_tensor, + 0.0005, + ) + def test_random_forward_exact(self) -> None: scale_module = transforms.RandomScale(scale=[0.5]) test_tensor = torch.arange(0, 1 * 1 * 10 * 10).view(1, 1, 10, 10).float()