From 49694ba878381be6d5a0bf2d39619d2a250784c8 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Fri, 31 Dec 2021 07:34:18 -0700 Subject: [PATCH 1/8] Add new StackImage parameterization & JIT support for SharedImage * Added `SimpleTensorParameterization` as a workaround for JIT not supporting `nn.ParameterList`. It also helps `StackImage` support tensor inputs. * Added JIT support for `SharedImage`. * Added a new parameterization called `StackImage`, which stacks multiple parameterizations (which can be on different devices) along the batch dimension. --- captum/optim/_param/image/images.py | 117 ++++++++++- tests/optim/param/test_images.py | 290 +++++++++++++++++++++++++++- 2 files changed, 393 insertions(+), 14 deletions(-) diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index cf4b01da0d..faeef9c81d 100644 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -407,6 +407,34 @@ def forward(self) -> torch.Tensor: return torch.stack(A).refine_names("B", "C", "H", "W") +class SimpleTensorParameterization(ImageParameterization): + """ + Parameterize a simple tensor with or without it requiring grad. + Compared to PixelImage, this parameterization has no specific shape requirements + and does not wrap inputs in nn.Parameter. + + This parameterization can, for example, be combined with StackImage for batch + dimensions that both require and don't require gradients. + + This parameterization can also be combined with nn.ModuleList as a workaround for + TorchScript / JIT not supporting nn.ParameterList. SharedImage uses this module + internally for this purpose. + """ + + def __init__(self, tensor: torch.Tensor = None) -> None: + """ + Args: + + tensor (torch.Tensor): The tensor to return every time this module is called. + """ + super().__init__() + assert isinstance(tensor, torch.Tensor) + self.tensor = tensor + + def forward(self) -> torch.Tensor: + return self.tensor + + class SharedImage(ImageParameterization): """ Share some image parameters across the batch to increase spatial alignment, @@ -420,6 +448,8 @@ class SharedImage(ImageParameterization): https://distill.pub/2018/differentiable-parameterizations/ """ + __constants__ = ["offset", "_supports_is_scripting"] + def __init__( self, shapes: Union[Tuple[Tuple[int]], Tuple[int]] = None, @@ -445,11 +475,17 @@ def __init__( assert len(shape) >= 2 and len(shape) <= 4 shape = ([1] * (4 - len(shape))) + list(shape) batch, channels, height, width = shape - A.append(torch.nn.Parameter(torch.randn([batch, channels, height, width]))) - self.shared_init = torch.nn.ParameterList(A) + shape_param = torch.nn.Parameter( + torch.randn([batch, channels, height, width]) + ) + A.append(SimpleTensorParameterization(shape_param)) + self.shared_init = torch.nn.ModuleList(A) self.parameterization = parameterization self.offset = self._get_offset(offset, len(A)) if offset is not None else None + # Check & store whether or not we can use torch.jit.is_scripting() + self._supports_is_scripting = torch.__version__ >= "1.6.0" + def _get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]]: """ Given offset values, return a list of offsets for _apply_offset to use. @@ -475,6 +511,7 @@ def _get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]] assert all([all([type(o) is int for o in v]) for v in offset]) return offset + @torch.jit.ignore def _apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: """ Apply list of offsets to list of tensors.
@@ -508,6 +545,7 @@ def _apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: A.append(x) return A + @torch.jit.ignore def _interpolate_tensor( self, x: torch.Tensor, batch: int, channels: int, height: int, width: int ) -> torch.Tensor: @@ -550,7 +588,7 @@ def forward(self) -> torch.Tensor: image = self.parameterization() x = [ self._interpolate_tensor( - shared_tensor, + shared_tensor(), image.size(0), image.size(1), image.size(2), @@ -560,7 +598,78 @@ def forward(self) -> torch.Tensor: ] if self.offset is not None: x = self._apply_offset(x) - return (image + sum(x)).refine_names("B", "C", "H", "W") + output = image + torch.cat(x, 0).sum(0, keepdim=True) + + if self._supports_is_scripting: + if torch.jit.is_scripting(): + return output + return output.refine_names("B", "C", "H", "W") + + +class StackImage(ImageParameterization): + """ + Stack multiple NCHW image parameterizations along their batch dimensions. + """ + + __constants__ = ["_supports_is_scripting", "output_device"] + + def __init__( + self, + parameterizations: List[Union[ImageParameterization, torch.Tensor]], + output_device: Optional[torch.device] = None, + ) -> None: + """ + Args: + + parameterizations (list of ImageParameterization and torch.Tensor): A list + of image parameterizations to stack across their batch dimensions. + output_device (torch.device): If the parameterizations are on different + devices, then their outputs will be moved to the device specified by + this variable. Default is set to None with the expectation that all + parameterizations are on the same device. + Default: None + """ + super().__init__() + assert len(parameterizations) > 0 + assert isinstance(parameterizations, (list, tuple)) + assert all( + [ + isinstance(param, (ImageParameterization, torch.Tensor)) + for param in parameterizations + ] + ) + parameterizations = [ + SimpleTensorParameterization(p) if isinstance(p, torch.Tensor) else p + for p in parameterizations + ] + self.parameterizations = torch.nn.ModuleList(parameterizations) + self.output_device = output_device + + # Check & store whether or not we can use torch.jit.is_scripting() + self._supports_is_scripting = torch.__version__ >= "1.6.0" + + def forward(self) -> torch.Tensor: + """ + Returns: + image (torch.Tensor): A set of NCHW image parameterization outputs stacked + along the batch dimension. 
+ """ + P = [] + for image_param in self.parameterizations: + img = image_param() + if self.output_device is not None: + img = img.to(self.output_device, dtype=img.dtype) + P.append(img) + + assert P[0].dim() == 4 + assert all([im.shape == P[0].shape for im in P]) + assert all([im.device == P[0].device for im in P]) + + image = torch.cat(P, 0) + if self._supports_is_scripting: + if torch.jit.is_scripting(): + return image + return image.refine_names("B", "C", "H", "W") class NaturalImage(ImageParameterization): diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index 7c420aa579..f0fc045934 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -325,6 +325,63 @@ def test_laplacianimage_init(self) -> None: assertArraysAlmostEqual(np.ones_like(test_np) * 0.5, test_np) +class TestSimpleTensorParameterization(BaseTest): + def test_simple_tensor_parameterization_no_grad(self) -> None: + test_input = torch.randn(1, 3, 4, 4) + image_param = images.SimpleTensorParameterization(test_input) + assertTensorAlmostEqual(self, image_param.tensor, test_input, 0.0) + self.assertFalse(image_param.tensor.requires_grad) + + test_output = image_param() + assertTensorAlmostEqual(self, test_output, test_input, 0.0) + self.assertFalse(image_param.tensor.requires_grad) + + def test_simple_tensor_parameterization_jit_module_no_grad(self) -> None: + test_input = torch.randn(1, 3, 4, 4) + image_param = images.SimpleTensorParameterization(test_input) + jit_image_param = torch.jit.script(image_param) + + test_output = jit_image_param() + assertTensorAlmostEqual(self, test_output, test_input, 0.0) + self.assertFalse(image_param.tensor.requires_grad) + + def test_simple_tensor_parameterization_with_grad(self) -> None: + test_input = torch.nn.Parameter(torch.randn(1, 3, 4, 4)) + image_param = images.SimpleTensorParameterization(test_input) + assertTensorAlmostEqual(self, image_param.tensor, test_input, 0.0) + self.assertTrue(image_param.tensor.requires_grad) + + test_output = image_param() + assertTensorAlmostEqual(self, test_output, test_input, 0.0) + self.assertTrue(image_param.tensor.requires_grad) + + def test_simple_tensor_parameterization_jit_module_with_grad(self) -> None: + test_input = torch.nn.Parameter(torch.randn(1, 3, 4, 4)) + image_param = images.SimpleTensorParameterization(test_input) + jit_image_param = torch.jit.script(image_param) + + test_output = jit_image_param() + assertTensorAlmostEqual(self, test_output, test_input, 0.0) + self.assertTrue(image_param.tensor.requires_grad) + + def test_simple_tensor_parameterization_cuda(self) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping SimpleTensorParameterization CUDA test due to not supporting" + + " CUDA." 
+ ) + test_input = torch.randn(1, 3, 4, 4).cuda() + image_param = images.SimpleTensorParameterization(test_input) + self.assertTrue(image_param.tensor.is_cuda) + assertTensorAlmostEqual(self, image_param.tensor, test_input, 0.0) + self.assertFalse(image_param.tensor.requires_grad) + + test_output = image_param() + self.assertTrue(test_output.is_cuda) + assertTensorAlmostEqual(self, test_output, test_input, 0.0) + self.assertFalse(image_param.tensor.requires_grad) + + class TestSharedImage(BaseTest): def test_sharedimage_get_offset_single_number(self) -> None: if torch.__version__ <= "1.2.0": @@ -502,9 +559,9 @@ def test_sharedimage_single_shape_hw_forward(self) -> None: test_tensor = image_param.forward() self.assertIsNone(image_param.offset) - self.assertEqual(image_param.shared_init[0].dim(), 4) + self.assertEqual(image_param.shared_init[0]().dim(), 4) self.assertEqual( - list(image_param.shared_init[0].shape), [1, 1] + list(shared_shapes) + list(image_param.shared_init[0]().shape), [1, 1] + list(shared_shapes) ) self.assertEqual(test_tensor.dim(), 4) self.assertEqual(test_tensor.size(0), batch) @@ -529,9 +586,9 @@ def test_sharedimage_single_shape_chw_forward(self) -> None: test_tensor = image_param.forward() self.assertIsNone(image_param.offset) - self.assertEqual(image_param.shared_init[0].dim(), 4) + self.assertEqual(image_param.shared_init[0]().dim(), 4) self.assertEqual( - list(image_param.shared_init[0].shape), [1] + list(shared_shapes) + list(image_param.shared_init[0]().shape), [1] + list(shared_shapes) ) self.assertEqual(test_tensor.dim(), 4) self.assertEqual(test_tensor.size(0), batch) @@ -556,8 +613,8 @@ def test_sharedimage_single_shape_bchw_forward(self) -> None: test_tensor = image_param.forward() self.assertIsNone(image_param.offset) - self.assertEqual(image_param.shared_init[0].dim(), 4) - self.assertEqual(list(image_param.shared_init[0].shape), list(shared_shapes)) + self.assertEqual(image_param.shared_init[0]().dim(), 4) + self.assertEqual(list(image_param.shared_init[0]().shape), list(shared_shapes)) self.assertEqual(test_tensor.dim(), 4) self.assertEqual(test_tensor.size(0), batch) self.assertEqual(test_tensor.size(1), channels) @@ -589,9 +646,9 @@ def test_sharedimage_multiple_shapes_forward(self) -> None: self.assertIsNone(image_param.offset) for i in range(len(shared_shapes)): - self.assertEqual(image_param.shared_init[i].dim(), 4) + self.assertEqual(image_param.shared_init[i]().dim(), 4) self.assertEqual( - list(image_param.shared_init[i].shape), list(shared_shapes[i]) + list(image_param.shared_init[i]().shape), list(shared_shapes[i]) ) self.assertEqual(test_tensor.dim(), 4) self.assertEqual(test_tensor.size(0), batch) @@ -624,10 +681,10 @@ def test_sharedimage_multiple_shapes_diff_len_forward(self) -> None: self.assertIsNone(image_param.offset) for i in range(len(shared_shapes)): - self.assertEqual(image_param.shared_init[i].dim(), 4) + self.assertEqual(image_param.shared_init[i]().dim(), 4) s_shape = list(shared_shapes[i]) s_shape = ([1] * (4 - len(s_shape))) + list(s_shape) - self.assertEqual(list(image_param.shared_init[i].shape), s_shape) + self.assertEqual(list(image_param.shared_init[i]().shape), s_shape) self.assertEqual(test_tensor.dim(), 4) self.assertEqual(test_tensor.size(0), batch) @@ -635,6 +692,219 @@ def test_sharedimage_multiple_shapes_diff_len_forward(self) -> None: self.assertEqual(test_tensor.size(2), size[0]) self.assertEqual(test_tensor.size(3), size[1]) + def test_sharedimage_multiple_shapes_diff_len_forward_jit_module(self) -> None: + if 
torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping SharedImage JIT module test due to insufficient Torch" + + " version." + ) + + shared_shapes = ( + (128 // 2, 128 // 2), + (7, 3, 128 // 4, 128 // 4), + (3, 128 // 8, 128 // 8), + (2, 4, 128 // 8, 128 // 8), + (1, 3, 128 // 16, 128 // 16), + (2, 2, 128 // 16, 128 // 16), + ) + batch = 6 + channels = 3 + size = (224, 224) + test_input = torch.ones(batch, channels, size[0], size[1]) # noqa: E731 + test_param = images.SimpleTensorParameterization(test_input) + image_param = images.SharedImage( + shapes=shared_shapes, parameterization=test_param + ) + jit_image_param = torch.jit.script(image_param) + test_tensor = jit_image_param() + + self.assertEqual(test_tensor.dim(), 4) + self.assertEqual(test_tensor.size(0), batch) + self.assertEqual(test_tensor.size(1), channels) + self.assertEqual(test_tensor.size(2), size[0]) + self.assertEqual(test_tensor.size(3), size[1]) + + +class TestStackImage(BaseTest): + def test_stackimage_init(self) -> None: + size = (4, 4) + fft_param_1 = images.FFTImage(size=size) + fft_param_2 = images.FFTImage(size=size) + param_list = [fft_param_1, fft_param_2] + stack_param = images.StackImage(parameterizations=param_list) + for image_param in stack_param.parameterizations: + self.assertIsInstance(image_param, images.FFTImage) + self.assertEqual(list(image_param().shape), [1, 3] + list(size)) + self.assertTrue(image_param().requires_grad) + + def test_stackimage_forward(self) -> None: + size = (4, 4) + fft_param_1 = images.FFTImage(size=size) + fft_param_2 = images.FFTImage(size=size) + param_list = [fft_param_1, fft_param_2] + stack_param = images.StackImage(parameterizations=param_list) + for image_param in stack_param.parameterizations: + self.assertIsInstance(image_param, images.FFTImage) + self.assertEqual(list(image_param().shape), [1, 3] + list(size)) + self.assertTrue(image_param().requires_grad) + + output_tensor = stack_param() + self.assertEqual(list(output_tensor.shape), [2, 3] + list(size)) + self.assertTrue(output_tensor.requires_grad) + self.assertIsNone(stack_param.output_device) + + def test_stackimage_forward_diff_image_params(self) -> None: + size = (4, 4) + fft_param = images.FFTImage(size=size) + pixel_param = images.PixelImage(size=size) + param_list = [fft_param, pixel_param] + + stack_param = images.StackImage(parameterizations=param_list) + + type_list = [images.FFTImage, images.PixelImage] + for image_param, expected_type in zip(stack_param.parameterizations, type_list): + self.assertIsInstance(image_param, expected_type) + self.assertEqual(list(image_param().shape), [1, 3] + list(size)) + self.assertTrue(image_param().requires_grad) + + output_tensor = stack_param() + self.assertEqual(list(output_tensor.shape), [2, 3] + list(size)) + self.assertTrue(output_tensor.requires_grad) + self.assertIsNone(stack_param.output_device) + + def test_stackimage_forward_diff_image_params_and_tensor_with_grad(self) -> None: + size = (4, 4) + fft_param = images.FFTImage(size=size) + pixel_param = images.PixelImage(size=size) + test_tensor = torch.nn.Parameter(torch.ones(1, 3, size[0], size[1])) + param_list = [fft_param, pixel_param, test_tensor] + + stack_param = images.StackImage(parameterizations=param_list) + + type_list = [ + images.FFTImage, + images.PixelImage, + images.SimpleTensorParameterization, + ] + for image_param, expected_type in zip(stack_param.parameterizations, type_list): + self.assertIsInstance(image_param, expected_type) + self.assertEqual(list(image_param().shape), [1, 
3] + list(size)) + self.assertTrue(image_param().requires_grad) + + output_tensor = stack_param() + self.assertEqual(list(output_tensor.shape), [3, 3] + list(size)) + self.assertTrue(output_tensor.requires_grad) + self.assertIsNone(stack_param.output_device) + + def test_stackimage_forward_diff_image_params_and_tensor_no_grad(self) -> None: + size = (4, 4) + fft_param = images.FFTImage(size=size) + pixel_param = images.PixelImage(size=size) + test_tensor = torch.ones(1, 3, size[0], size[1]) + param_list = [fft_param, pixel_param, test_tensor] + + stack_param = images.StackImage(parameterizations=param_list) + + type_list = [ + images.FFTImage, + images.PixelImage, + images.SimpleTensorParameterization, + ] + for image_param, expected_type in zip(stack_param.parameterizations, type_list): + self.assertIsInstance(image_param, expected_type) + self.assertEqual(list(image_param().shape), [1, 3] + list(size)) + + self.assertTrue(stack_param.parameterizations[0]().requires_grad) + self.assertTrue(stack_param.parameterizations[1]().requires_grad) + self.assertFalse(stack_param.parameterizations[2]().requires_grad) + + output_tensor = stack_param() + self.assertEqual(list(output_tensor.shape), [3, 3] + list(size)) + self.assertTrue(output_tensor.requires_grad) + self.assertIsNone(stack_param.output_device) + + def test_stackimage_forward_multi_gpu(self) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping StackImage multi GPU test due to not supporting CUDA." + ) + if torch.cuda.device_count() == 1: + raise unittest.SkipTest( + "Skipping StackImage multi GPU device test due to not having enough" + + " GPUs available." + ) + size = (4, 4) + + num_cuda_devices = torch.cuda.device_count() + param_list, device_list = [], [] + + fft_param = images.FFTImage(size=size).cpu() + param_list.append(fft_param) + device_list.append(torch.device("cpu")) + + for i in range(num_cuda_devices - 1): + device = torch.device("cuda:" + str(i)) + device_list.append(device) + fft_param = images.FFTImage(size=size).to(device) + param_list.append(fft_param) + + output_device = torch.device("cuda:" + str(num_cuda_devices - 1)) + stack_param = images.StackImage( + parameterizations=param_list, output_device=output_device + ) + + for image_param, torch_device in zip( + stack_param.parameterizations, device_list + ): + self.assertIsInstance(image_param, images.FFTImage) + self.assertEqual(list(image_param().shape), [1, 3] + list(size)) + self.assertEqual(image_param().device, torch_device) + self.assertTrue(image_param().requires_grad) + + output_tensor = stack_param() + self.assertEqual( + list(output_tensor.shape), [len(param_list)] + [3] + list(size) + ) + self.assertTrue(output_tensor.requires_grad) + self.assertEqual(stack_param().device, output_device) + + def test_stackimage_forward_multi_device_cpu_gpu(self) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping StackImage multi device test due to not supporting CUDA." 
+ ) + size = (4, 4) + param_list, device_list = [], [] + + fft_param = images.FFTImage(size=size).cpu() + param_list.append(fft_param) + device_list.append(torch.device("cpu")) + + device = torch.device("cuda:0") + device_list.append(device) + fft_param = images.FFTImage(size=size).to(device) + param_list.append(fft_param) + + output_device = torch.device("cuda:0") + stack_param = images.StackImage( + parameterizations=param_list, output_device=output_device + ) + + for image_param, torch_device in zip( + stack_param.parameterizations, device_list + ): + self.assertIsInstance(image_param, images.FFTImage) + self.assertEqual(list(image_param().shape), [1, 3] + list(size)) + self.assertEqual(image_param().device, torch_device) + self.assertTrue(image_param().requires_grad) + + output_tensor = stack_param() + self.assertEqual( + list(output_tensor.shape), [len(param_list)] + [3] + list(size) + ) + self.assertTrue(output_tensor.requires_grad) + self.assertEqual(stack_param().device, output_device) + class TestNaturalImage(BaseTest): def test_natural_image_0(self) -> None: From 26d9519ff247905712fcaf1c3735c585de45a8c0 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Fri, 31 Dec 2021 08:14:25 -0700 Subject: [PATCH 2/8] Fix test version checks --- tests/optim/param/test_images.py | 35 +++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index f0fc045934..b2904c1f9e 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -693,7 +693,7 @@ def test_sharedimage_multiple_shapes_diff_len_forward(self) -> None: self.assertEqual(test_tensor.size(3), size[1]) def test_sharedimage_multiple_shapes_diff_len_forward_jit_module(self) -> None: - if torch.__version__ <= "1.2.0": + if torch.__version__ <= "1.8.0": raise unittest.SkipTest( "Skipping SharedImage JIT module test due to insufficient Torch" + " version." @@ -727,6 +727,10 @@ def test_sharedimage_multiple_shapes_diff_len_forward_jit_module(self) -> None: class TestStackImage(BaseTest): def test_stackimage_init(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping StackImage init test due to insufficient Torch version." + ) size = (4, 4) fft_param_1 = images.FFTImage(size=size) fft_param_2 = images.FFTImage(size=size) @@ -738,6 +742,10 @@ def test_stackimage_init(self) -> None: self.assertTrue(image_param().requires_grad) def test_stackimage_forward(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping StackImage forward test due to insufficient Torch version." + ) size = (4, 4) fft_param_1 = images.FFTImage(size=size) fft_param_2 = images.FFTImage(size=size) @@ -754,6 +762,11 @@ def test_stackimage_forward(self) -> None: self.assertIsNone(stack_param.output_device) def test_stackimage_forward_diff_image_params(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping StackImage forward with diff image params test due to" + + " insufficient Torch version." 
+ ) size = (4, 4) fft_param = images.FFTImage(size=size) pixel_param = images.PixelImage(size=size) param_list = [fft_param, pixel_param] @@ -773,6 +786,11 @@ def test_stackimage_forward_diff_image_params(self) -> None: self.assertIsNone(stack_param.output_device) def test_stackimage_forward_diff_image_params_and_tensor_with_grad(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping StackImage forward with diff image params and tensor with" + + " grad test due to insufficient Torch version." + ) size = (4, 4) fft_param = images.FFTImage(size=size) pixel_param = images.PixelImage(size=size) @@ -797,6 +815,11 @@ def test_stackimage_forward_diff_image_params_and_tensor_with_grad(self) -> None self.assertIsNone(stack_param.output_device) def test_stackimage_forward_diff_image_params_and_tensor_no_grad(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping StackImage forward with diff image params and tensor with" + + " no grad test due to insufficient Torch version." + ) size = (4, 4) fft_param = images.FFTImage(size=size) pixel_param = images.PixelImage(size=size) @@ -824,6 +847,11 @@ def test_stackimage_forward_diff_image_params_and_tensor_no_grad(self) -> None: self.assertIsNone(stack_param.output_device) def test_stackimage_forward_multi_gpu(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping StackImage forward multi GPU test due to insufficient" + + " Torch version." + ) if not torch.cuda.is_available(): raise unittest.SkipTest( "Skipping StackImage multi GPU test due to not supporting CUDA." @@ -869,6 +897,11 @@ def test_stackimage_forward_multi_gpu(self) -> None: self.assertEqual(stack_param().device, output_device) def test_stackimage_forward_multi_device_cpu_gpu(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping StackImage forward multi device test due to insufficient" + + " Torch version." + ) if not torch.cuda.is_available(): raise unittest.SkipTest( "Skipping StackImage multi device test due to not supporting CUDA." From cf921baec24ad84f46ad3dfce3cbf400dd745a6f Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Fri, 31 Dec 2021 14:27:58 -0700 Subject: [PATCH 3/8] More tests & new AugmentedImageParameterization base class * Added `AugmentedImageParameterization` class to use as a base for `SharedImage` and `StackImage`. * Removed `PixelImage`'s 3-channel assert, as there was no reason for the limitation. * Added tests for `InputParameterization`, `ImageParameterization`, & `AugmentedImageParameterization`. --- captum/optim/_param/image/images.py | 24 ++++++++-- tests/optim/param/test_images.py | 73 ++++++++++++++++++++++++----- 2 files changed, 80 insertions(+), 17 deletions(-) diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index faeef9c81d..376ca782bb 100644 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -133,6 +133,10 @@ class ImageParameterization(InputParameterization): pass +class AugmentedImageParameterization(ImageParameterization): + pass + + class FFTImage(ImageParameterization): """ Parameterize an image using inverse real 2D FFT @@ -305,8 +309,6 @@ def __init__( assert init.dim() == 3 or init.dim() == 4 if init.dim() == 3: init = init.unsqueeze(0) - assert init.shape[1] == 3, "PixelImage init should have 3 channels, " - f"input has {init.shape[1]} channels."
self.image = nn.Parameter(init) def forward(self) -> torch.Tensor: @@ -432,10 +434,14 @@ def __init__(self, tensor: torch.Tensor = None) -> None: self.tensor = tensor def forward(self) -> torch.Tensor: + """ + Returns: + tensor (torch.Tensor): The tensor stored during initialization. + """ return self.tensor -class SharedImage(ImageParameterization): +class SharedImage(AugmentedImageParameterization): """ Share some image parameters across the batch to increase spatial alignment, by using interpolated lower resolution tensors. @@ -585,6 +591,10 @@ def _interpolate_tensor( return x def forward(self) -> torch.Tensor: + """ + Returns: + output (torch.Tensor): An NCHW image parameterization output. + """ image = self.parameterization() x = [ self._interpolate_tensor( @@ -606,7 +616,7 @@ def forward(self) -> torch.Tensor: return output.refine_names("B", "C", "H", "W") -class StackImage(ImageParameterization): +class StackImage(AugmentedImageParameterization): """ Stack multiple NCHW image parameterizations along their batch dimensions. """ @@ -691,7 +701,9 @@ def __init__( channels: int = 3, batch: int = 1, init: Optional[torch.Tensor] = None, - parameterization: ImageParameterization = FFTImage, + parameterization: Union[ + ImageParameterization, AugmentedImageParameterization + ] = FFTImage, squash_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, decorrelation_module: Optional[nn.Module] = ToRGB(transform="klt"), decorrelate_init: bool = True, @@ -758,9 +770,11 @@ def forward(self) -> torch.Tensor: "ImageTensor", "InputParameterization", "ImageParameterization", + "AugmentedImageParameterization", "FFTImage", "PixelImage", "LaplacianImage", "SharedImage", + "StackImage", "NaturalImage", ] diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index b2904c1f9e..4225d05959 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -79,7 +79,7 @@ def test_export_and_open_local_image(self) -> None: self.assertTrue(torch.is_tensor(new_tensor)) assertTensorAlmostEqual(self, image_tensor, new_tensor) - def test_natural_image_cuda(self) -> None: + def test_image_tensor_cuda(self) -> None: if not torch.cuda.is_available(): raise unittest.SkipTest( "Skipping ImageTensor CUDA test due to not supporting CUDA." 
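To make the `nn.ParameterList` workaround from PATCH 1 concrete: each parameter is wrapped in a small tensor-returning module, and the wrappers are stored in an `nn.ModuleList`, which TorchScript can iterate, unlike `nn.ParameterList` on the PyTorch versions these patches target. The snippet below is a minimal, self-contained sketch of that pattern; the class names are hypothetical stand-ins, not part of the patch.

```python
import torch
import torch.nn as nn


class TensorHolder(nn.Module):
    # Hypothetical stand-in for SimpleTensorParameterization: stores a
    # tensor (here an nn.Parameter) and returns it on every call.
    def __init__(self, tensor: torch.Tensor) -> None:
        super().__init__()
        self.tensor = tensor

    def forward(self) -> torch.Tensor:
        return self.tensor


class SumOfShared(nn.Module):
    # Storing the wrappers in nn.ModuleList keeps the parameters visible to
    # optimizers while remaining scriptable; iterating an nn.ParameterList
    # inside a scripted forward is what TorchScript rejects.
    def __init__(self, shapes) -> None:
        super().__init__()
        self.shared = nn.ModuleList(
            [TensorHolder(nn.Parameter(torch.randn(s))) for s in shapes]
        )

    def forward(self) -> torch.Tensor:
        out = []
        for holder in self.shared:
            out.append(holder())
        # Same reduction SharedImage uses: concatenate, then sum over batch.
        return torch.cat(out, 0).sum(0, keepdim=True)


module = torch.jit.script(SumOfShared([(1, 3, 4, 4), (1, 3, 4, 4)]))
print(module().shape)  # torch.Size([1, 3, 4, 4])
```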
@@ -88,7 +88,31 @@ def test_natural_image_cuda(self) -> None: self.assertTrue(image_t.is_cuda) +class TestInputParameterization(BaseTest): + def test_subclass(self) -> None: + self.assertTrue(issubclass(images.InputParameterization, torch.nn.Module)) + + +class TestImageParameterization(BaseTest): + def test_subclass(self) -> None: + self.assertTrue( + issubclass(images.ImageParameterization, images.InputParameterization) + ) + + +class TestAugmentedImageParameterization(BaseTest): + def test_subclass(self) -> None: + self.assertTrue( + issubclass( + images.AugmentedImageParameterization, images.ImageParameterization + ) + ) + + class TestFFTImage(BaseTest): + def test_subclass(self) -> None: + self.assertTrue(issubclass(images.FFTImage, images.ImageParameterization)) + def test_pytorch_fftfreq(self) -> None: image = images.FFTImage((1, 1)) _, _, fftfreq = image.get_fft_funcs() @@ -219,6 +243,9 @@ def test_fftimage_forward_init_batch(self) -> None: class TestPixelImage(BaseTest): + def test_subclass(self) -> None: + self.assertTrue(issubclass(images.PixelImage, images.ImageParameterization)) + def test_pixelimage_random(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( @@ -251,17 +278,6 @@ def test_pixelimage_init(self) -> None: self.assertEqual(image_param.image.size(3), size[1]) assertTensorAlmostEqual(self, image_param.image, init_tensor, 0) - def test_pixelimage_init_error(self) -> None: - if torch.__version__ <= "1.2.0": - raise unittest.SkipTest( - "Skipping PixelImage init due to insufficient Torch version." - ) - size = (224, 224) - channels = 2 - init_tensor = torch.randn(channels, *size) - with self.assertRaises(AssertionError): - images.PixelImage(size=size, channels=channels, init=init_tensor) - def test_pixelimage_random_forward(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( @@ -298,6 +314,9 @@ def test_pixelimage_init_forward(self) -> None: class TestLaplacianImage(BaseTest): + def test_subclass(self) -> None: + self.assertTrue(issubclass(images.LaplacianImage, images.ImageParameterization)) + def test_laplacianimage_random_forward(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( @@ -326,6 +345,13 @@ def test_laplacianimage_init(self) -> None: class TestSimpleTensorParameterization(BaseTest): + def test_subclass(self) -> None: + self.assertTrue( + issubclass( + images.SimpleTensorParameterization, images.ImageParameterization + ) + ) + def test_simple_tensor_parameterization_no_grad(self) -> None: test_input = torch.randn(1, 3, 4, 4) image_param = images.SimpleTensorParameterization(test_input) @@ -337,6 +363,11 @@ def test_simple_tensor_parameterization_no_grad(self) -> None: self.assertFalse(image_param.tensor.requires_grad) def test_simple_tensor_parameterization_jit_module_no_grad(self) -> None: + if torch.__version__ <= "1.8.0": + raise unittest.SkipTest( + "Skipping SimpleTensorParameterization JIT module test due to" + + " insufficient Torch version." + ) test_input = torch.randn(1, 3, 4, 4) image_param = images.SimpleTensorParameterization(test_input) jit_image_param = torch.jit.script(image_param) @@ -356,6 +387,11 @@ def test_simple_tensor_parameterization_with_grad(self) -> None: self.assertTrue(image_param.tensor.requires_grad) def test_simple_tensor_parameterization_jit_module_with_grad(self) -> None: + if torch.__version__ <= "1.8.0": + raise unittest.SkipTest( + "Skipping SimpleTensorParameterization JIT module test due to" + + " insufficient Torch version." 
+ ) test_input = torch.nn.Parameter(torch.randn(1, 3, 4, 4)) image_param = images.SimpleTensorParameterization(test_input) jit_image_param = torch.jit.script(image_param) @@ -383,6 +419,11 @@ def test_simple_tensor_parameterization_cuda(self) -> None: class TestSharedImage(BaseTest): + def test_subclass(self) -> None: + self.assertTrue( + issubclass(images.SharedImage, images.AugmentedImageParameterization) + ) + def test_sharedimage_get_offset_single_number(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( @@ -726,6 +767,11 @@ def test_sharedimage_multiple_shapes_diff_len_forward_jit_module(self) -> None: class TestStackImage(BaseTest): + def test_subclass(self) -> None: + self.assertTrue( + issubclass(images.StackImage, images.AugmentedImageParameterization) + ) + def test_stackimage_init(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( @@ -940,6 +986,9 @@ def test_stackimage_forward_multi_device_cpu_gpu(self) -> None: class TestNaturalImage(BaseTest): + def test_subclass(self) -> None: + self.assertTrue(issubclass(images.NaturalImage, images.ImageParameterization)) + def test_natural_image_0(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( From 9b7ce93eda157d326c50ddd5d6f9cccf35192e97 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Fri, 31 Dec 2021 19:40:40 -0700 Subject: [PATCH 4/8] Add JIT support for SharedImage._interpolate_tensor * Added JIT support for SharedImage's interpolation operations. * Unfortunately, JIT support required splitting SharedImage's bilinear and trilinear resizing into separate functions, as Unions of tuples are currently broken. Union support is also a relatively recent addition, so this change lets SharedImage support older PyTorch versions as well. --- captum/optim/_param/image/images.py | 92 +++++++++++++++++++++++++---- tests/optim/param/test_images.py | 77 ++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 12 deletions(-) diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index 376ca782bb..8a03cda706 100644 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -454,7 +454,12 @@ class SharedImage(AugmentedImageParameterization): https://distill.pub/2018/differentiable-parameterizations/ """ - __constants__ = ["offset", "_supports_is_scripting"] + __constants__ = [ + "offset", + "_supports_is_scripting", + "_has_align_corners", + "_has_recompute_scale_factor", + ] def __init__( self, @@ -491,6 +496,8 @@ def __init__( # Check & store whether or not we can use torch.jit.is_scripting() self._supports_is_scripting = torch.__version__ >= "1.6.0" + self._has_align_corners = torch.__version__ >= "1.3.0" + self._has_recompute_scale_factor = torch.__version__ >= "1.6.0" def _get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]]: """ @@ -551,7 +558,75 @@ def _apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: A.append(x) return A - @torch.jit.ignore + def _interpolate_bilinear( + self, + x: torch.Tensor, + size: Tuple[int, int], + ) -> torch.Tensor: + """ + Perform interpolation without any warnings. + + Args: + + x (torch.Tensor): The NCHW tensor to resize. + size (tuple of int): The desired output size to resize the input + to, with a format of: [height, width]. + + Returns: + x (torch.Tensor): A resized NCHW tensor.
+ """ + assert x.dim() == 4 + assert len(size) == 2 + + if self._has_align_corners: + if self._has_recompute_scale_factor: + x = F.interpolate( + x, + size=size, + mode="bilinear", + align_corners=False, + recompute_scale_factor=False, + ) + else: + x = F.interpolate(x, size=size, mode="bilinear", align_corners=False) + else: + x = F.interpolate(x, size=size, mode="bilinear") + return x + + def _interpolate_trilinear( + self, + x: torch.Tensor, + size: Tuple[int, int, int], + ) -> torch.Tensor: + """ + Perform interpolation without any warnings. + + Args: + + x (torch.Tensor): The NCHW tensor to resize. + size (tuple of int): The desired output size to resize the input + to, with a format of: [channels, height, width]. + + Returns: + x (torch.Tensor): A resized NCHW tensor. + """ + x = x.unsqueeze(0) + assert x.dim() == 5 + if self._has_align_corners: + if self._has_recompute_scale_factor: + x = F.interpolate( + x, + size=size, + mode="trilinear", + align_corners=False, + recompute_scale_factor=False, + ) + else: + x = F.interpolate(x, size=size, mode="trilinear", align_corners=False) + else: + x = F.interpolate(x, size=size, mode="trilinear") + return x.squeeze(0) + def _interpolate_tensor( self, x: torch.Tensor, batch: int, channels: int, height: int, width: int ) -> torch.Tensor: @@ -572,21 +647,14 @@ def _interpolate_tensor( """ if x.size(1) == channels: - mode = "bilinear" size = (height, width) + x = self._interpolate_bilinear(x, size=size) else: - mode = "trilinear" - x = x.unsqueeze(0) size = (channels, height, width) - x = F.interpolate(x, size=size, mode=mode) - x = x.squeeze(0) if len(size) == 3 else x + x = self._interpolate_trilinear(x, size=size) if x.size(0) != batch: x = x.permute(1, 0, 2, 3) - x = F.interpolate( - x.unsqueeze(0), - size=(batch, x.size(2), x.size(3)), - mode="trilinear", - ).squeeze(0) + x = self._interpolate_trilinear(x, size=(batch, x.size(2), x.size(3))) x = x.permute(1, 0, 2, 3) return x diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index 4225d05959..4da9726ed9 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -424,6 +424,74 @@ def test_subclass(self) -> None: issubclass(images.SharedImage, images.AugmentedImageParameterization) ) + def test_sharedimage_interpolate_bilinear(self) -> None: + shared_shapes = (128 // 2, 128 // 2) + test_param = lambda: torch.ones(3, 3, 224, 224) # noqa: E731 + image_param = images.SharedImage( + shapes=shared_shapes, parameterization=test_param + ) + + size = (224, 128) + test_input = torch.randn(1, 3, 128, 128) + + test_output = image_param._interpolate_bilinear(test_input.clone(), size=size) + expected_output = torch.nn.functional.interpolate( + test_input.clone(), size=size, mode="bilinear" + ) + assertTensorAlmostEqual(self, test_output, expected_output, 0.0) + + size = (128, 128) + test_input = torch.randn(1, 3, 224, 224) + + test_output = image_param._interpolate_bilinear(test_input.clone(), size=size) + expected_output = torch.nn.functional.interpolate( + test_input.clone(), size=size, mode="bilinear" + ) + assertTensorAlmostEqual(self, test_output, expected_output, 0.0) + + def test_sharedimage_interpolate_trilinear(self) -> None: + shared_shapes = (128 // 2, 128 // 2) + test_param = lambda: torch.ones(3, 3, 224, 224) # noqa: E731 + image_param = images.SharedImage( + shapes=shared_shapes, parameterization=test_param + ) + + size = (3, 224, 128) + test_input = torch.randn(1, 1, 128, 128) + + test_output = 
image_param._interpolate_trilinear(test_input.clone(), size=size) + expected_output = torch.nn.functional.interpolate( + test_input.clone().unsqueeze(0), size=size, mode="trilinear" + ).squeeze(0) + assertTensorAlmostEqual(self, test_output, expected_output, 0.0) + + size = (2, 128, 128) + test_input = torch.randn(1, 4, 224, 224) + + test_output = image_param._interpolate_trilinear(test_input.clone(), size=size) + expected_output = torch.nn.functional.interpolate( + test_input.clone().unsqueeze(0), size=size, mode="trilinear" + ).squeeze(0) + assertTensorAlmostEqual(self, test_output, expected_output, 0.0) + + def test_torch_version_check(self) -> None: + shared_shapes = (128 // 2, 128 // 2) + test_param = lambda: torch.ones(3, 3, 224, 224) # noqa: E731 + image_param = images.SharedImage( + shapes=shared_shapes, parameterization=test_param + ) + + has_align_corners = torch.__version__ >= "1.3.0" + self.assertEqual(image_param._has_align_corners, has_align_corners) + + has_recompute_scale_factor = torch.__version__ >= "1.6.0" + self.assertEqual( + image_param._has_recompute_scale_factor, has_recompute_scale_factor + ) + + supports_is_scripting = torch.__version__ >= "1.6.0" + self.assertEqual(image_param._supports_is_scripting, supports_is_scripting) + def test_sharedimage_get_offset_single_number(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( @@ -772,6 +840,15 @@ def test_subclass(self) -> None: issubclass(images.StackImage, images.AugmentedImageParameterization) ) + def test_stackimage_torch_version_check(self) -> None: + img_param_1 = images.SimpleTensorParameterization(torch.ones(1, 3, 4, 4)) + img_param_2 = images.SimpleTensorParameterization(torch.ones(1, 3, 4, 4)) + param_list = [img_param_1, img_param_2] + stack_param = images.StackImage(parameterizations=param_list) + + supports_is_scripting = torch.__version__ >= "1.6.0" + self.assertEqual(stack_param._supports_is_scripting, supports_is_scripting) + def test_stackimage_init(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( From 8f782afb0dc73a8fc212c43bc5577d776f345228 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sat, 1 Jan 2022 11:47:53 -0700 Subject: [PATCH 5/8] Add dim variable to StackImage * Added the `dim` variable to `StackImage` so that users can choose what dimension to stack the image parameterizations across. --- captum/optim/_param/image/images.py | 17 ++++++++----- tests/optim/param/test_images.py | 39 +++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index 8a03cda706..73db6286de 100644 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -689,11 +689,12 @@ class StackImage(AugmentedImageParameterization): Stack multiple NCHW image parameterizations along their batch dimensions. """ - __constants__ = ["_supports_is_scripting", "output_device"] + __constants__ = ["_supports_is_scripting", "dim", "output_device"] def __init__( self, parameterizations: List[Union[ImageParameterization, torch.Tensor]], + dim: int = 0, output_device: Optional[torch.device] = None, ) -> None: """ @@ -701,10 +702,13 @@ def __init__( parameterizations (list of ImageParameterization and torch.Tensor): A list of image parameterizations to stack across their batch dimensions. - output_device (torch.device): If the parameterizations are on different - devices, then their outputs will be moved to the device specified by - this variable. 
Default is set to None with the expectation that all - parameterizations are on the same device. + dim (int, optional): Optionally specify the dim to concatenate + parameterization outputs on. Default is set to the batch dimension. + Default: 0 + output_device (torch.device, optional): If the parameterizations are on + different devices, then their outputs will be moved to the device + specified by this variable. Default is set to None with the expectation + that all parameterizations are on the same device. Default: None """ super().__init__() @@ -721,6 +725,7 @@ def __init__( for p in parameterizations ] self.parameterizations = torch.nn.ModuleList(parameterizations) + self.dim = dim self.output_device = output_device # Check & store whether or not we can use torch.jit.is_scripting() @@ -743,7 +748,7 @@ def forward(self) -> torch.Tensor: assert all([im.shape == P[0].shape for im in P]) assert all([im.device == P[0].device for im in P]) - image = torch.cat(P, 0) + image = torch.cat(P, dim=self.dim) if self._supports_is_scripting: if torch.jit.is_scripting(): return image diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index 4da9726ed9..2e2f36b9b5 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -424,6 +424,28 @@ def test_subclass(self) -> None: issubclass(images.SharedImage, images.AugmentedImageParameterization) ) + def test_sharedimage_init(self) -> None: + shared_shapes = ( + (1, 3, 128 // 2, 128 // 2), + (1, 3, 128 // 4, 128 // 4), + (1, 3, 128 // 8, 128 // 8), + ) + test_param = images.SimpleTensorParameterization(torch.ones(4, 3, 4, 4)) + shared_param = images.SharedImage( + shapes=shared_shapes, parameterization=test_param + ) + + self.assertIsInstance(shared_param.shared_init, torch.nn.ModuleList) + self.assertEqual(len(shared_param.shared_init), len(shared_shapes)) + for shared_init, shape in zip(shared_param.shared_init, shared_shapes): + self.assertIsInstance(shared_init, images.SimpleTensorParameterization) + self.assertEqual(list(shared_init().shape), list(shape)) + + self.assertIsInstance( + shared_param.parameterization, images.SimpleTensorParameterization + ) + self.assertIsNone(shared_param.offset) + def test_sharedimage_interpolate_bilinear(self) -> None: shared_shapes = (128 // 2, 128 // 2) test_param = lambda: torch.ones(3, 3, 224, 224) # noqa: E731 image_param = images.SharedImage( shapes=shared_shapes, parameterization=test_param ) @@ -859,11 +881,28 @@ def test_stackimage_init(self) -> None: fft_param_2 = images.FFTImage(size=size) param_list = [fft_param_1, fft_param_2] stack_param = images.StackImage(parameterizations=param_list) + + self.assertIsInstance(stack_param.parameterizations, torch.nn.ModuleList) + self.assertEqual(len(stack_param.parameterizations), 2) + self.assertEqual(stack_param.dim, 0) + for image_param in stack_param.parameterizations: self.assertIsInstance(image_param, images.FFTImage) self.assertEqual(list(image_param().shape), [1, 3] + list(size)) self.assertTrue(image_param().requires_grad) + def test_stackimage_dim(self) -> None: + img_param_r = images.SimpleTensorParameterization(torch.ones(1, 1, 4, 4)) + img_param_g = images.SimpleTensorParameterization(torch.ones(1, 1, 4, 4)) + img_param_b = images.SimpleTensorParameterization(torch.ones(1, 1, 4, 4)) + param_list = [img_param_r, img_param_g, img_param_b] + stack_param = images.StackImage(parameterizations=param_list, dim=1) + + self.assertEqual(stack_param.dim, 1) + + test_output = stack_param() + self.assertEqual(list(test_output.shape), [1, 3, 4, 4]) + def test_stackimage_forward(self) -> None: if
torch.__version__ <= "1.2.0": raise unittest.SkipTest( From 9bf9ec20e86ba50d75e91c5e90ed27ed7cb38d3c Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sat, 1 Jan 2022 12:08:41 -0700 Subject: [PATCH 6/8] Fix test version --- tests/optim/param/test_images.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index 2e2f36b9b5..75e20544ad 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -892,6 +892,10 @@ def test_stackimage_init(self) -> None: self.assertTrue(image_param().requires_grad) def test_stackimage_dim(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping StackImage dim test due to insufficient Torch version." + ) img_param_r = images.SimpleTensorParameterization(torch.ones(1, 1, 4, 4)) img_param_g = images.SimpleTensorParameterization(torch.ones(1, 1, 4, 4)) img_param_b = images.SimpleTensorParameterization(torch.ones(1, 1, 4, 4)) From 79d6d74a6c9b2d3c3ad8644d8a33cfd265a8d077 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Mon, 9 May 2022 12:08:18 -0600 Subject: [PATCH 7/8] AugmentedImageParameterization -> ImageParameterization --- captum/optim/_param/image/images.py | 13 +++---------- tests/optim/param/test_images.py | 17 ++--------------- 2 files changed, 5 insertions(+), 25 deletions(-) diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index 71b61b10fa..cb47521dbe 100644 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -133,10 +133,6 @@ class ImageParameterization(InputParameterization): pass -class AugmentedImageParameterization(ImageParameterization): - pass - - class FFTImage(ImageParameterization): """ Parameterize an image using inverse real 2D FFT @@ -461,7 +457,7 @@ def forward(self) -> torch.Tensor: return self.tensor -class SharedImage(AugmentedImageParameterization): +class SharedImage(ImageParameterization): """ Share some image parameters across the batch to increase spatial alignment, by using interpolated lower resolution tensors. @@ -704,7 +700,7 @@ def forward(self) -> torch.Tensor: return output.refine_names("B", "C", "H", "W") -class StackImage(AugmentedImageParameterization): +class StackImage(ImageParameterization): """ Stack multiple NCHW image parameterizations along their batch dimensions. 
""" @@ -794,9 +790,7 @@ def __init__( channels: int = 3, batch: int = 1, init: Optional[torch.Tensor] = None, - parameterization: Union[ - ImageParameterization, AugmentedImageParameterization - ] = FFTImage, + parameterization: ImageParameterization = FFTImage, squash_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, decorrelation_module: Optional[nn.Module] = ToRGB(transform="klt"), decorrelate_init: bool = True, @@ -884,7 +878,6 @@ def forward(self) -> torch.Tensor: "ImageTensor", "InputParameterization", "ImageParameterization", - "AugmentedImageParameterization", "FFTImage", "PixelImage", "LaplacianImage", diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index d9bdd32c5b..d5f5180325 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -100,15 +100,6 @@ def test_subclass(self) -> None: ) -class TestAugmentedImageParameterization(BaseTest): - def test_subclass(self) -> None: - self.assertTrue( - issubclass( - images.AugmentedImageParameterization, images.ImageParameterization - ) - ) - - class TestFFTImage(BaseTest): def test_subclass(self) -> None: self.assertTrue(issubclass(images.FFTImage, images.ImageParameterization)) @@ -568,9 +559,7 @@ def test_simple_tensor_parameterization_cuda(self) -> None: class TestSharedImage(BaseTest): def test_subclass(self) -> None: - self.assertTrue( - issubclass(images.SharedImage, images.AugmentedImageParameterization) - ) + self.assertTrue(issubclass(images.SharedImage, images.ImageParameterization)) def test_sharedimage_init(self) -> None: shared_shapes = ( @@ -1006,9 +995,7 @@ def test_sharedimage_multiple_shapes_diff_len_forward_jit_module(self) -> None: class TestStackImage(BaseTest): def test_subclass(self) -> None: - self.assertTrue( - issubclass(images.StackImage, images.AugmentedImageParameterization) - ) + self.assertTrue(issubclass(images.StackImage, images.ImageParameterization)) def test_stackimage_torch_version_check(self) -> None: img_param_1 = images.SimpleTensorParameterization(torch.ones(1, 3, 4, 4)) From 9c3dd87c9019ed1a5d06b86fd3d84f003302fe07 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sun, 15 May 2022 09:52:53 -0600 Subject: [PATCH 8/8] Remove unused code --- captum/optim/_param/image/images.py | 67 +++++++++------------------- tests/optim/param/test_images.py | 68 +---------------------------- 2 files changed, 22 insertions(+), 113 deletions(-) diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index 37033c0a5f..af0bf2486b 100644 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -461,12 +461,7 @@ class SharedImage(ImageParameterization): https://distill.pub/2018/differentiable-parameterizations/ """ - __constants__ = [ - "offset", - "_supports_is_scripting", - "_has_align_corners", - "_has_recompute_scale_factor", - ] + __constants__ = ["offset"] def __init__( self, @@ -501,11 +496,6 @@ def __init__( self.parameterization = parameterization self.offset = self._get_offset(offset, len(A)) if offset is not None else None - # Check & store whether or not we can use torch.jit.is_scripting() - self._supports_is_scripting = torch.__version__ >= "1.6.0" - self._has_align_corners = torch.__version__ >= "1.3.0" - self._has_recompute_scale_factor = torch.__version__ >= "1.6.0" - def _get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]]: """ Given offset values, return a list of offsets for _apply_offset to use. 
@@ -585,19 +575,13 @@ def _interpolate_bilinear( assert x.dim() == 4 assert len(size) == 2 - if self._has_align_corners: - if self._has_recompute_scale_factor: - x = F.interpolate( - x, - size=size, - mode="bilinear", - align_corners=False, - recompute_scale_factor=False, - ) - else: - x = F.interpolate(x, size=size, mode="bilinear", align_corners=False) - else: - x = F.interpolate(x, size=size, mode="bilinear") + x = F.interpolate( + x, + size=size, + mode="bilinear", + align_corners=False, + recompute_scale_factor=False, + ) return x def _interpolate_trilinear( @@ -619,19 +603,13 @@ def _interpolate_trilinear( """ x = x.unsqueeze(0) assert x.dim() == 5 - if self._has_align_corners: - if self._has_recompute_scale_factor: - x = F.interpolate( - x, - size=size, - mode="trilinear", - align_corners=False, - recompute_scale_factor=False, - ) - else: - x = F.interpolate(x, size=size, mode="trilinear", align_corners=False) - else: - x = F.interpolate(x, size=size, mode="trilinear") + x = F.interpolate( + x, + size=size, + mode="trilinear", + align_corners=False, + recompute_scale_factor=False, + ) return x.squeeze(0) def _interpolate_tensor( @@ -685,9 +663,8 @@ def forward(self) -> torch.Tensor: x = self._apply_offset(x) output = image + torch.cat(x, 0).sum(0, keepdim=True) - if self._supports_is_scripting: - if torch.jit.is_scripting(): - return output + if torch.jit.is_scripting(): + return output return output.refine_names("B", "C", "H", "W") @@ -696,7 +673,7 @@ class StackImage(ImageParameterization): Stack multiple NCHW image parameterizations along their batch dimensions. """ - __constants__ = ["_supports_is_scripting", "dim", "output_device"] + __constants__ = ["dim", "output_device"] def __init__( self, @@ -735,9 +712,6 @@ def __init__( self.dim = dim self.output_device = output_device - # Check & store whether or not we can use torch.jit.is_scripting() - self._supports_is_scripting = torch.__version__ >= "1.6.0" - def forward(self) -> torch.Tensor: """ Returns: @@ -756,9 +730,8 @@ def forward(self) -> torch.Tensor: assert all([im.device == P[0].device for im in P]) image = torch.cat(P, dim=self.dim) - if self._supports_is_scripting: - if torch.jit.is_scripting(): - return image + if torch.jit.is_scripting(): + return image return image.refine_names("B", "C", "H", "W") diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index b3576d86f0..617d34a3a3 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -419,7 +419,7 @@ def test_simple_tensor_parameterization_no_grad(self) -> None: self.assertFalse(image_param.tensor.requires_grad) def test_simple_tensor_parameterization_jit_module_no_grad(self) -> None: - if torch.__version__ <= "1.8.0": + if version.parse(torch.__version__) <= version.parse("1.8.0"): raise unittest.SkipTest( "Skipping SimpleTensorParameterization JIT module test due to" + " insufficient Torch version." 
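The switch in these test hunks from comparing `torch.__version__` as a raw string to `version.parse` (presumably `packaging.version`, imported elsewhere in the test file) fixes a real pitfall: Python compares strings lexicographically, so checks like `<= "1.8.0"` misfire once multi-digit releases such as 1.10 exist.

```python
from packaging import version

# Lexicographic string comparison sorts "1.10.0" before "1.8.0":
print("1.10.0" <= "1.8.0")                                # True (wrong)
# Parsed versions compare release components numerically:
print(version.parse("1.10.0") <= version.parse("1.8.0"))  # False (correct)
```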
@@ -550,24 +550,6 @@ def test_sharedimage_interpolate_trilinear(self) -> None: ).squeeze(0) assertTensorAlmostEqual(self, test_output, expected_output, 0.0) - def test_torch_version_check(self) -> None: - shared_shapes = (128 // 2, 128 // 2) - test_param = lambda: torch.ones(3, 3, 224, 224) # noqa: E731 - image_param = images.SharedImage( - shapes=shared_shapes, parameterization=test_param - ) - - has_align_corners = torch.__version__ >= "1.3.0" - self.assertEqual(image_param._has_align_corners, has_align_corners) - - has_recompute_scale_factor = torch.__version__ >= "1.6.0" - self.assertEqual( - image_param._has_recompute_scale_factor, has_recompute_scale_factor - ) - - supports_is_scripting = torch.__version__ >= "1.6.0" - self.assertEqual(image_param._supports_is_scripting, supports_is_scripting) - def test_sharedimage_get_offset_single_number(self) -> None: shared_shapes = (128 // 2, 128 // 2) test_param = lambda: torch.ones(3, 3, 224, 224) # noqa: E731 @@ -825,7 +807,7 @@ def test_sharedimage_multiple_shapes_diff_len_forward(self) -> None: self.assertEqual(test_tensor.size(3), size[1]) def test_sharedimage_multiple_shapes_diff_len_forward_jit_module(self) -> None: - if torch.__version__ <= "1.8.0": + if version.parse(torch.__version__) <= version.parse("1.8.0"): raise unittest.SkipTest( "Skipping SharedImage JIT module test due to insufficient Torch" + " version." @@ -861,20 +843,7 @@ class TestStackImage(BaseTest): def test_subclass(self) -> None: self.assertTrue(issubclass(images.StackImage, images.ImageParameterization)) - def test_stackimage_torch_version_check(self) -> None: - img_param_1 = images.SimpleTensorParameterization(torch.ones(1, 3, 4, 4)) - img_param_2 = images.SimpleTensorParameterization(torch.ones(1, 3, 4, 4)) - param_list = [img_param_1, img_param_2] - stack_param = images.StackImage(parameterizations=param_list) - - supports_is_scripting = torch.__version__ >= "1.6.0" - self.assertEqual(stack_param._supports_is_scripting, supports_is_scripting) - def test_stackimage_init(self) -> None: - if torch.__version__ <= "1.2.0": - raise unittest.SkipTest( - "Skipping StackImage init test due to insufficient Torch version." - ) size = (4, 4) fft_param_1 = images.FFTImage(size=size) fft_param_2 = images.FFTImage(size=size) @@ -891,10 +860,6 @@ def test_stackimage_init(self) -> None: self.assertTrue(image_param().requires_grad) def test_stackimage_dim(self) -> None: - if torch.__version__ <= "1.2.0": - raise unittest.SkipTest( - "Skipping StackImage dim test due to insufficient Torch version." - ) img_param_r = images.SimpleTensorParameterization(torch.ones(1, 1, 4, 4)) img_param_g = images.SimpleTensorParameterization(torch.ones(1, 1, 4, 4)) img_param_b = images.SimpleTensorParameterization(torch.ones(1, 1, 4, 4)) @@ -907,10 +872,6 @@ def test_stackimage_dim(self) -> None: self.assertEqual(list(test_output.shape), [1, 3, 4, 4]) def test_stackimage_forward(self) -> None: - if torch.__version__ <= "1.2.0": - raise unittest.SkipTest( - "Skipping StackImage forward test due to insufficient Torch version." - ) size = (4, 4) fft_param_1 = images.FFTImage(size=size) fft_param_2 = images.FFTImage(size=size) @@ -927,11 +888,6 @@ def test_stackimage_forward(self) -> None: self.assertIsNone(stack_param.output_device) def test_stackimage_forward_diff_image_params(self) -> None: - if torch.__version__ <= "1.2.0": - raise unittest.SkipTest( - "Skipping StackImage forward with diff image params test due to" - + " insufficient Torch version." 
- ) size = (4, 4) fft_param = images.FFTImage(size=size) pixel_param = images.PixelImage(size=size) @@ -951,11 +907,6 @@ def test_stackimage_forward_diff_image_params(self) -> None: self.assertIsNone(stack_param.output_device) def test_stackimage_forward_diff_image_params_and_tensor_with_grad(self) -> None: - if torch.__version__ <= "1.2.0": - raise unittest.SkipTest( - "Skipping StackImage forward with diff image params and tensor with" - + " grad test due to insufficient Torch version." - ) size = (4, 4) fft_param = images.FFTImage(size=size) pixel_param = images.PixelImage(size=size) @@ -980,11 +931,6 @@ def test_stackimage_forward_diff_image_params_and_tensor_with_grad(self) -> None self.assertIsNone(stack_param.output_device) def test_stackimage_forward_diff_image_params_and_tensor_no_grad(self) -> None: - if torch.__version__ <= "1.2.0": - raise unittest.SkipTest( - "Skipping StackImage forward with diff image params and tensor with" - + " no grad test due to insufficient Torch version." - ) size = (4, 4) fft_param = images.FFTImage(size=size) pixel_param = images.PixelImage(size=size) @@ -1012,11 +958,6 @@ def test_stackimage_forward_diff_image_params_and_tensor_no_grad(self) -> None: self.assertIsNone(stack_param.output_device) def test_stackimage_forward_multi_gpu(self) -> None: - if torch.__version__ <= "1.2.0": - raise unittest.SkipTest( - "Skipping StackImage forward multi GPU test due to insufficient" - + " Torch version." - ) if not torch.cuda.is_available(): raise unittest.SkipTest( "Skipping StackImage multi GPU test due to not supporting CUDA." @@ -1062,11 +1003,6 @@ def test_stackimage_forward_multi_gpu(self) -> None: self.assertEqual(stack_param().device, output_device) def test_stackimage_forward_multi_device_cpu_gpu(self) -> None: - if torch.__version__ <= "1.2.0": - raise unittest.SkipTest( - "Skipping StackImage forward multi device test due to insufficient" - + " Torch version." - ) if not torch.cuda.is_available(): raise unittest.SkipTest( "Skipping StackImage multi device test due to not supporting CUDA."
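Taken together, the series leaves behind a `StackImage` that accepts mixed modules and tensors, and a `SharedImage` that can be scripted end to end. The sketch below summarizes the intended usage based on the tests above; the import path follows the file paths in the diffs, and the behavior assumes a PyTorch version recent enough for `torch.jit.is_scripting()`.

```python
import torch
from captum.optim._param.image import images

# Stack a trainable FFT parameterization with a frozen tensor along the
# batch dimension; plain tensors are wrapped in SimpleTensorParameterization
# automatically (PATCH 1), and dim / output_device are optional (PATCH 5).
stack = images.StackImage(
    parameterizations=[images.FFTImage(size=(4, 4)), torch.ones(1, 3, 4, 4)]
)
print(stack().shape)  # [2, 3, 4, 4]; requires_grad=True via the FFT half

# SharedImage is scriptable once its inner parameterization is a module,
# mirroring test_sharedimage_multiple_shapes_diff_len_forward_jit_module.
shared = images.SharedImage(
    shapes=((64, 64), (3, 32, 32)),
    parameterization=images.SimpleTensorParameterization(
        torch.ones(6, 3, 224, 224)
    ),
)
print(torch.jit.script(shared)().shape)  # torch.Size([6, 3, 224, 224])
```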