Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions captum/optim/_param/image/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ class RandomScale(nn.Module):
"_has_align_corners",
"recompute_scale_factor",
"_has_recompute_scale_factor",
"antialias",
"_has_antialias",
"_is_distribution",
]

Expand All @@ -452,6 +454,7 @@ def __init__(
mode: str = "bilinear",
align_corners: Optional[bool] = False,
recompute_scale_factor: bool = False,
antialias: bool = False,
) -> None:
"""
Args:
Expand All @@ -468,6 +471,10 @@ def __init__(
recompute_scale_factor (bool, optional): Whether or not to recompute the
scale factor. See documentation of F.interpolate for more details.
Default: False
antialias (bool, optional): Whether or not to use anti-aliasing. This
feature is currently only available for "bilinear" and "bicubic"
modes. See documentation of F.interpolate for more details.
Default: False
"""
super().__init__()
assert mode not in ["linear", "trilinear"]
Expand All @@ -488,8 +495,10 @@ def __init__(
self.mode = mode
self.align_corners = align_corners if mode not in ["nearest", "area"] else None
self.recompute_scale_factor = recompute_scale_factor
self.antialias = antialias
self._has_align_corners = torch.__version__ >= "1.3.0"
self._has_recompute_scale_factor = torch.__version__ >= "1.6.0"
self._has_antialias = torch.__version__ >= "1.11.0"

def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor:
"""
Expand All @@ -505,13 +514,23 @@ def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor:
"""
if self._has_align_corners:
if self._has_recompute_scale_factor:
x = F.interpolate(
x,
scale_factor=scale,
mode=self.mode,
align_corners=self.align_corners,
recompute_scale_factor=self.recompute_scale_factor,
)
if self._has_antialias:
x = F.interpolate(
x,
scale_factor=scale,
mode=self.mode,
align_corners=self.align_corners,
recompute_scale_factor=self.recompute_scale_factor,
antialias=self.antialias,
)
else:
x = F.interpolate(
x,
scale_factor=scale,
mode=self.mode,
align_corners=self.align_corners,
recompute_scale_factor=self.recompute_scale_factor,
)
else:
x = F.interpolate(
x,
Expand Down Expand Up @@ -555,13 +574,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class RandomScaleAffine(nn.Module):
"""
Apply random rescaling on a NCHW tensor.

This random scaling transform utilizes F.affine_grid & F.grid_sample, and as a
result has two key differences from the default RandomScale transform. This
transform either shrinks an image while adding a background, or center crops the image
and then resizes it to a larger size. This means that the output image shape is the
same shape as the input image.

In contrast to RandomScaleAffine, the default RandomScale transform simply resizes
the input image using F.interpolate.
"""
Expand Down Expand Up @@ -1328,6 +1345,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"CenterCrop",
"center_crop",
"RandomScale",
"RandomScaleAffine",
"RandomSpatialJitter",
"RandomRotation",
"ScaleInputRange",
Expand Down
36 changes: 36 additions & 0 deletions tests/optim/param/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_random_scale_init(self) -> None:
self.assertEqual(scale_module.mode, "bilinear")
self.assertFalse(scale_module.align_corners)
self.assertFalse(scale_module.recompute_scale_factor)
self.assertFalse(scale_module.antialias)

def test_random_scale_tensor_scale(self) -> None:
scale = torch.tensor([1, 0.975, 1.025, 0.95, 1.05])
Expand Down Expand Up @@ -56,6 +57,9 @@ def test_random_scale_torch_version_check(self) -> None:
scale_module._has_recompute_scale_factor, has_recompute_scale_factor
)

has_antialias = torch.__version__ >= "1.11.0"
self.assertEqual(scale_module._has_antialias, has_antialias)

def test_random_scale_downscaling(self) -> None:
scale_module = transforms.RandomScale(scale=[0.5])
test_tensor = torch.arange(0, 1 * 1 * 10 * 10).view(1, 1, 10, 10).float()
Expand Down Expand Up @@ -108,6 +112,38 @@ def test_random_scale_upscaling(self) -> None:
0,
)

def test_random_scale_antialias(self) -> None:
if torch.__version__ < "1.11.0":
raise unittest.SkipTest(
"Skipping RandomScale antialias test"
+ " due to insufficient Torch version."
)
scale_module = transforms.RandomScale(scale=[0.5], antialias=True)
test_tensor = torch.arange(0, 1 * 1 * 10 * 10).view(1, 1, 10, 10).float()

scaled_tensor = scale_module._scale_tensor(test_tensor, 0.5)

expected_tensor = torch.tensor(
[
[
[
[7.8571, 9.6429, 11.6429, 13.6429, 15.4286],
[25.7143, 27.5000, 29.5000, 31.5000, 33.2857],
[45.7143, 47.5000, 49.5000, 51.5000, 53.2857],
[65.7143, 67.5000, 69.5000, 71.5000, 73.2857],
[83.5714, 85.3571, 87.3571, 89.3571, 91.1429],
]
]
]
)

assertTensorAlmostEqual(
self,
scaled_tensor,
expected_tensor,
0.0005,
)

def test_random_forward_exact(self) -> None:
scale_module = transforms.RandomScale(scale=[0.5])
test_tensor = torch.arange(0, 1 * 1 * 10 * 10).view(1, 1, 10, 10).float()
Expand Down