From d60055988d572c02a8b408ba84698b308e182172 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Mon, 26 Apr 2021 11:58:03 -0600 Subject: [PATCH 01/33] Fix the Inception5h model's download link --- captum/optim/models/_image/inception_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/captum/optim/models/_image/inception_v1.py b/captum/optim/models/_image/inception_v1.py index 65e73e32d9..102581c095 100644 --- a/captum/optim/models/_image/inception_v1.py +++ b/captum/optim/models/_image/inception_v1.py @@ -7,7 +7,7 @@ GS_SAVED_WEIGHTS_URL = ( "https://github.com/pytorch/captum/raw/" - + "optim-wip/captum/optim/_models/inception5h.pth" + + "optim-wip/captum/optim/models/_image/inception5h.pth" ) From 8612feec489c2bbe6b1f8b92ec8a0b3b5fc9875d Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Mon, 26 Apr 2021 12:04:59 -0600 Subject: [PATCH 02/33] Fix black error --- captum/attr/_core/feature_ablation.py | 8 ++++---- captum/metrics/_core/infidelity.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index ecd096ba6e..f7a03e21dd 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -62,7 +62,7 @@ def attribute( additional_forward_args: Any = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, perturbations_per_eval: int = 1, - **kwargs: Any + **kwargs: Any, ) -> TensorOrTupleOfTensorsGeneric: r""" Args: @@ -321,7 +321,7 @@ def attribute( baselines, feature_mask, perturbations_per_eval, - **kwargs + **kwargs, ): # modified_eval dimensions: 1D tensor with length # equal to #num_examples * #features in batch @@ -373,7 +373,7 @@ def _ablation_generator( baselines, input_mask, perturbations_per_eval, - **kwargs + **kwargs, ): """ This method is a generator which yields each perturbation to be evaluated @@ -458,7 +458,7 @@ def _ablation_generator( baseline, num_features_processed, num_features_processed + current_num_ablated_features, - **extra_args + **extra_args, ) # current_features[i] has dimension diff --git a/captum/metrics/_core/infidelity.py b/captum/metrics/_core/infidelity.py index b396e20543..2f06e497d8 100644 --- a/captum/metrics/_core/infidelity.py +++ b/captum/metrics/_core/infidelity.py @@ -68,7 +68,7 @@ def sub_infidelity_perturb_func_decorator(pertub_func: Callable) -> Callable: def default_perturb_func( inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None ): - r"""""" + r""" """ inputs_perturbed = ( pertub_func(inputs, baselines) if baselines is not None @@ -380,7 +380,7 @@ def _generate_perturbations( """ def call_perturb_func(): - r"""""" + r""" """ baselines_pert = None inputs_pert: Union[Tensor, Tuple[Tensor, ...]] if len(inputs_expanded) == 1: From d31d60d605b3d532288f668feb1e52f68867d058 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Mon, 26 Apr 2021 15:35:51 -0600 Subject: [PATCH 03/33] Ensure history has no gradient --- captum/optim/_core/optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/captum/optim/_core/optimization.py b/captum/optim/_core/optimization.py index 63f30a016e..65fb4d7c5c 100644 --- a/captum/optim/_core/optimization.py +++ b/captum/optim/_core/optimization.py @@ -138,7 +138,7 @@ def optimize( while stop_criteria(step, self, history, optimizer): optimizer.zero_grad() loss_value = loss_summarize_fn(self.loss()) - history.append(loss_value) + history.append(loss_value.clone().detach().cpu()) loss_value.backward() optimizer.step() step += 1 
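Note on patches 03 and 04: appending the live `loss_value` tensor to `history` keeps each step's autograd graph reachable (the tensor still references its `grad_fn`), so memory grows with the number of steps, while `.clone().detach()` records a graph-free copy. Patch 04 then drops the `.cpu()` call, since moving the scalar to the CPU forces a device synchronization on every step. A minimal sketch of the pattern (a generic optimization loop for illustration, not Captum's actual `InputOptimization` API):

    import torch

    param = torch.ones(3, requires_grad=True)
    optimizer = torch.optim.Adam([param], lr=0.025)
    history = []
    for _ in range(10):
        optimizer.zero_grad()
        loss = (param ** 2).sum()
        # Record a detached, on-device copy so the autograd graph is not retained.
        history.append(loss.clone().detach())
        loss.backward()
        optimizer.step()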
From bd3c7fe35e8527535cb1315c928a26c52e46406f Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Mon, 26 Apr 2021 15:53:25 -0600 Subject: [PATCH 04/33] Remove .cpu() to improve optimization speed --- captum/optim/_core/optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/captum/optim/_core/optimization.py b/captum/optim/_core/optimization.py index 65fb4d7c5c..75d613a336 100644 --- a/captum/optim/_core/optimization.py +++ b/captum/optim/_core/optimization.py @@ -138,7 +138,7 @@ def optimize( while stop_criteria(step, self, history, optimizer): optimizer.zero_grad() loss_value = loss_summarize_fn(self.loss()) - history.append(loss_value.clone().detach().cpu()) + history.append(loss_value.clone().detach()) loss_value.backward() optimizer.step() step += 1 From bbebf22208c7d0c6028cd976873f897f85d17ecc Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Mon, 3 May 2021 10:00:42 -0600 Subject: [PATCH 05/33] Temporarily disable nightly build tests affected by pytorch/pytorch#57421 --- tests/optim/core/test_optimization.py | 2 ++ tests/optim/param/test_images.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/optim/core/test_optimization.py b/tests/optim/core/test_optimization.py index 581cc194fc..434b6c24fc 100644 --- a/tests/optim/core/test_optimization.py +++ b/tests/optim/core/test_optimization.py @@ -9,6 +9,7 @@ class TestInputOptimization(BaseTest): + @unittest.skipIf(torch.__version__ > "1.8.1", "Bug in PyTorch nightly build") def test_input_optimization(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( @@ -22,6 +23,7 @@ def test_input_optimization(self) -> None: self.assertTrue(history[0] > history[-1]) self.assertTrue(len(history) == n_steps) + @unittest.skipIf(torch.__version__ > "1.8.1", "Bug in PyTorch nightly build") def test_input_optimization_param(self) -> None: """Test for optimizing param without model""" if torch.__version__ <= "1.2.0": diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index fb354e82dd..50becf0e26 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -16,6 +16,7 @@ class TestImageTensor(BaseTest): + @unittest.skipIf(torch.__version__ > "1.8.1", "Bug in PyTorch nightly build") def test_repr(self) -> None: self.assertEqual(str(images.ImageTensor()), "ImageTensor([])") @@ -565,6 +566,7 @@ def test_sharedimage_multiple_shapes_diff_len_forward(self) -> None: class TestNaturalImage(BaseTest): + @unittest.skipIf(torch.__version__ > "1.8.1", "Bug in PyTorch nightly build") def test_natural_image_0(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( @@ -573,7 +575,7 @@ def test_natural_image_0(self) -> None: image_param = images.NaturalImage(size=(1, 1)) image_np = image_param.forward().detach().numpy() assertArraysAlmostEqual(image_np, np.ones_like(image_np) * 0.5) - + @unittest.skipIf(torch.__version__ > "1.8.1", "Bug in PyTorch nightly build") def test_natural_image_1(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( From 0fb9eb5074d41eed9f074f7e30310a31d4584c72 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Mon, 3 May 2021 10:06:06 -0600 Subject: [PATCH 06/33] Fix flake8 error --- tests/optim/param/test_images.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index 50becf0e26..ddbd2fbd2e 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -575,6 +575,7 @@ def 
test_natural_image_0(self) -> None: image_param = images.NaturalImage(size=(1, 1)) image_np = image_param.forward().detach().numpy() assertArraysAlmostEqual(image_np, np.ones_like(image_np) * 0.5) + @unittest.skipIf(torch.__version__ > "1.8.1", "Bug in PyTorch nightly build") def test_natural_image_1(self) -> None: if torch.__version__ <= "1.2.0": From 7c453e0e76a0550df57715a93ee5fe7d23bce9b7 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Thu, 6 May 2021 07:35:45 -0600 Subject: [PATCH 07/33] Remove ImageTensor test skips & add new tests * Remove `ImageTensor` test skips as the `torch.Tensor`'s `__new__` function has been fixed. * Add tests for `ImageTensor` functions. * Removed old `AlphaChannelLoss` code. --- captum/optim/_param/image/transforms.py | 47 ------------------------- tests/optim/core/test_optimization.py | 2 -- tests/optim/param/test_images.py | 47 +++++++++++++++++++++++-- tests/optim/param/test_transforms.py | 19 ---------- 4 files changed, 44 insertions(+), 71 deletions(-) diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py index 049366e1c4..899695594b 100644 --- a/captum/optim/_param/image/transforms.py +++ b/captum/optim/_param/image/transforms.py @@ -503,53 +503,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) -class AlphaChannelLoss(nn.Module): - """ - TODO: Fix AlphaChannelLoss - Transform for calculating alpha channel loss, without altering the input tensor. - Loss values are calculated in such a way that opaque and transparent regions of - the tensor are automatically balanced. - ​ - See: https://distill.pub/2018/differentiable-parameterizations/ - Mordvintsev, et al., "Differentiable Image Parameterizations", Distill, 2018. - ​ - Args: - scale (float, sequence): Tuple of rescaling values to randomly select from. - crop_size (int, sequence, int, optional): The desired cropped output size - for secondary alpha channel loss. - background (tensor, optional): An NCHW image tensor to be used as the - alpha channel's background. 
- """ - - def __init__( - self, - scale: NumSeqOrTensorType, - crop_size: Optional[Tuple[int, int]] = None, - background: Optional[torch.Tensor] = None, - ) -> None: - raise NotImplementedError # We are not ready for this - super().__init__() - self.random_scale = RandomScale(scale=scale) - self.crop_size = crop_size - self.random_crop = RandomCrop(crop_size) - self.blend_alpha = BlendAlpha(background=background) - self.loss = 0 - - def forward(self, x: torch.Tensor) -> torch.Tensor: - assert x.dim() == 4 # Should be of shape (batch, channel, height, width) - assert x.size(1) == 4 # Channel dim should be rgba - - x_shifted = torch.cat([self.blend_alpha(x.clone()), x.clone()[:, 3:]], 1) - - x_shifted = self.random_scale(x_shifted) - x_shifted_crop = self.random_crop(x_shifted) - - self.loss = (1.0 - x_shifted[:, 3:].mean()) + ( - (1.0 - x_shifted_crop[:, 3:].mean()) * 0.5 - ) - return x - - __all__ = [ "BlendAlpha", "IgnoreAlpha", diff --git a/tests/optim/core/test_optimization.py b/tests/optim/core/test_optimization.py index 434b6c24fc..581cc194fc 100644 --- a/tests/optim/core/test_optimization.py +++ b/tests/optim/core/test_optimization.py @@ -9,7 +9,6 @@ class TestInputOptimization(BaseTest): - @unittest.skipIf(torch.__version__ > "1.8.1", "Bug in PyTorch nightly build") def test_input_optimization(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( @@ -23,7 +22,6 @@ def test_input_optimization(self) -> None: self.assertTrue(history[0] > history[-1]) self.assertTrue(len(history) == n_steps) - @unittest.skipIf(torch.__version__ > "1.8.1", "Bug in PyTorch nightly build") def test_input_optimization_param(self) -> None: """Test for optimizing param without model""" if torch.__version__ <= "1.2.0": diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index ddbd2fbd2e..a2c5b292cd 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -16,10 +16,53 @@ class TestImageTensor(BaseTest): - @unittest.skipIf(torch.__version__ > "1.8.1", "Bug in PyTorch nightly build") def test_repr(self) -> None: self.assertEqual(str(images.ImageTensor()), "ImageTensor([])") + def test_new(self) -> None: + x = torch.ones(5) + test_tensor = images.ImageTensor(x) + self.assertTrue(torch.is_tensor(test_tensor)) + self.assertEqual(x.shape, test_tensor.shape) + + def test_new_numpy(self) -> None: + x = torch.ones(5).numpy() + test_tensor = images.ImageTensor(x) + self.assertTrue(torch.is_tensor(test_tensor)) + self.assertEqual(x.shape, test_tensor.shape) + + def test_new_list(self) -> None: + x = torch.ones(5).tolist() + test_tensor = images.ImageTensor(x) + self.assertTrue(torch.is_tensor(test_tensor)) + self.assertEqual(x.shape, test_tensor.shape) + + def test_torch_function(self) -> None: + x = torch.ones(5) + image_tensor = images.ImageTensor(x) + image_tensor = (image_tensor * 1) * torch.ones(5) + self.assertEqual(image_tensor.sum().item(), torch.ones(5).sum().item()) + + def test_load_image_from_url(self) -> None: + img_url = ( + "https://github.com/pytorch/captum" + + "/raw/master/website/static/img/captum_logo.png" + ) + new_tensor = images.ImageTensor().open(img_url) + self.assertTrue(torch.is_tensor(new_tensor)) + self.assertEqual(list(new_tensor.shape), [3, 54, 208]) + + def test_export_and_open_local_image(self) -> None: + x = torch.ones(1, 3, 5, 5) + image_tensor = images.ImageTensor(x) + + filename = "image_tensor.jpg" + image_tensor.export(filename) + new_tensor = images.ImageTensor().open(filename) + + 
self.assertTrue(torch.is_tensor(new_tensor)) + assertTensorAlmostEqual(self, image_tensor, new_tensor) + def test_natural_image_cuda(self) -> None: if not torch.cuda.is_available(): raise unittest.SkipTest( @@ -566,7 +609,6 @@ def test_sharedimage_multiple_shapes_diff_len_forward(self) -> None: class TestNaturalImage(BaseTest): - @unittest.skipIf(torch.__version__ > "1.8.1", "Bug in PyTorch nightly build") def test_natural_image_0(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( @@ -576,7 +618,6 @@ def test_natural_image_0(self) -> None: image_np = image_param.forward().detach().numpy() assertArraysAlmostEqual(image_np, np.ones_like(image_np) * 0.5) - @unittest.skipIf(torch.__version__ > "1.8.1", "Bug in PyTorch nightly build") def test_natural_image_1(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( diff --git a/tests/optim/param/test_transforms.py b/tests/optim/param/test_transforms.py index 7dff299c73..ade3ba37e7 100644 --- a/tests/optim/param/test_transforms.py +++ b/tests/optim/param/test_transforms.py @@ -650,22 +650,3 @@ def test_random_crop(self) -> None: x_out = crop_transform(x) self.assertEqual(list(x_out.shape), [1, 4, 160, 160]) - - -class TestAlphaChannelLoss(BaseTest): - def test_alpha_channel_loss_forward(self) -> None: - raise unittest.SkipTest( - "Skipping AlphaChannelLoss test until function is ready." - ) - # crop_size = [160, 160] - # scale = [0.6, 0.7, 0.8, 0.9, 1.0, 1.1] - - # alpha_loss_transform = transforms.AlphaChannelLoss( - # scale=scale, crop_size=crop_size - # ) - # x = torch.randn(1, 4, 224, 224) - - # x_out = alpha_loss_transform(x) - - # assertTensorAlmostEqual(self, x_out, x, 0) - # self.assertNotEqual(alpha_loss_transforms.loss, 0) From 5d0d143e7ba76ead89e297e66cbf11b42d64ea7d Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Thu, 6 May 2021 07:53:31 -0600 Subject: [PATCH 08/33] Fix ImageTensor __new__ list test --- tests/optim/param/test_images.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index a2c5b292cd..07a5abe4e9 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -32,8 +32,8 @@ def test_new_numpy(self) -> None: self.assertEqual(x.shape, test_tensor.shape) def test_new_list(self) -> None: - x = torch.ones(5).tolist() - test_tensor = images.ImageTensor(x) + x = torch.ones(5) + test_tensor = images.ImageTensor(x.tolist()) self.assertTrue(torch.is_tensor(test_tensor)) self.assertEqual(x.shape, test_tensor.shape) From 50455bf9bcbc113071d4cc2806d44a3a2450db4d Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Wed, 12 May 2021 12:00:09 -0600 Subject: [PATCH 09/33] Fix NaturalImage device bug * Fix `NaturalImage` device bug. * Set `decorrelate_init` default to `False`. * Fix `NaturalImage` size type. 
--- captum/optim/_param/image/images.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/captum/optim/_param/image/images.py index 354fc1cb9f..b121706efe 100755 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -418,17 +418,17 @@ class NaturalImage(ImageParameterization): def __init__( self, - size: Tuple[int, int] = [224, 224], + size: Tuple[int, int] = (224, 224), channels: int = 3, batch: int = 1, init: Optional[torch.Tensor] = None, parameterization: ImageParameterization = FFTImage, squash_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, - decorrelation_module: Optional[nn.Module] = ToRGB(transform="klt"), - decorrelate_init: bool = True, + decorrelation_module: Optional[nn.Module] = None, + decorrelate_init: bool = False, ) -> None: super().__init__() - self.decorrelate = decorrelation_module + self.decorrelate = decorrelation_module or ToRGB(transform="klt") if init is not None: assert init.dim() == 3 or init.dim() == 4 if decorrelate_init: From c9ece95e86b1abc24f7449b1801121f32a95cf30 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Wed, 12 May 2021 12:14:37 -0600 Subject: [PATCH 10/33] Set decorrelate_init default back to True --- captum/optim/_param/image/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/captum/optim/_param/image/images.py index b121706efe..e43c56ab01 100755 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -425,7 +425,7 @@ def __init__( parameterization: ImageParameterization = FFTImage, squash_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, decorrelation_module: Optional[nn.Module] = None, - decorrelate_init: bool = False, + decorrelate_init: bool = True, ) -> None: super().__init__() self.decorrelate = decorrelation_module or ToRGB(transform="klt") From 4bd8e5bccbfea23ad2ee1a1364d09da036ede413 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Wed, 12 May 2021 12:43:21 -0600 Subject: [PATCH 11/33] Check for presence of Pillow / PIL library in applicable ImageTensor tests --- tests/optim/param/test_images.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/optim/param/test_images.py index 07a5abe4e9..5ac71eea7e 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -44,6 +44,14 @@ def test_torch_function(self) -> None: self.assertEqual(image_tensor.sum().item(), torch.ones(5).sum().item()) def test_load_image_from_url(self) -> None: + try: + from PIL import Image # noqa: F401 + + except (ImportError, AssertionError): + raise unittest.SkipTest( + "Module Pillow / PIL not found, skipping ImageTensor load from url" + + " test" + ) img_url = ( "https://github.com/pytorch/captum" + "/raw/master/website/static/img/captum_logo.png" @@ -53,6 +61,14 @@ def test_load_image_from_url(self) -> None: self.assertEqual(list(new_tensor.shape), [3, 54, 208]) def test_export_and_open_local_image(self) -> None: + try: + from PIL import Image # noqa: F401 + + except (ImportError, AssertionError): + raise unittest.SkipTest( + "Module Pillow / PIL not found, skipping ImageTensor export and save" + + " local image test" + ) x = torch.ones(1, 3, 5, 5) image_tensor = images.ImageTensor(x) From 71a370d865cc904808f5d254b439c35850a2fa00 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Mon, 17 May 2021 09:05:09 -0600 Subject: [PATCH 12/33] Update Conda installation script to
latest version --- scripts/install_via_conda.sh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/scripts/install_via_conda.sh b/scripts/install_via_conda.sh index b647293f47..6eedd15f27 100755 --- a/scripts/install_via_conda.sh +++ b/scripts/install_via_conda.sh @@ -34,8 +34,8 @@ else fi # install other deps -conda install -y numpy sphinx pytest flake8 ipywidgets ipython -conda install -y -c conda-forge black matplotlib pytest-cov sphinx-autodoc-typehints mypy flask isort +conda install -y numpy sphinx pytest flake8 ipywidgets ipython scikit-learn +conda install -y -c conda-forge black matplotlib pytest-cov sphinx-autodoc-typehints mypy flask isort flask-compress # install node/yarn for insights build conda install -y -c conda-forge yarn @@ -44,5 +44,4 @@ conda update -y --no-channel-priority -c conda-forge nodejs # build insights and install captum -# TODO: remove CI=false when we want React warnings treated as errors -CI=false BUILD_INSIGHTS=1 python setup.py develop +BUILD_INSIGHTS=1 python setup.py develop From e7fcdc017dc9fde9371078ddd98bf78da6500a60 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Wed, 19 May 2021 12:11:47 -0600 Subject: [PATCH 13/33] Add SkipLayer to models __init__ --- captum/optim/models/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/captum/optim/models/__init__.py b/captum/optim/models/__init__.py index 635d1eb5b6..a970e68ec4 100755 --- a/captum/optim/models/__init__.py +++ b/captum/optim/models/__init__.py @@ -1,5 +1,6 @@ from ._common import ( # noqa: F401 RedirectedReluLayer, + SkipLayer, collect_activations, get_model_layers, replace_layers, @@ -10,6 +11,7 @@ __all__ = [ "RedirectedReluLayer", + "SkipLayer", "collect_activations", "get_model_layers", "replace_layers", From 015890eec540a95e5ac8dd1a737c4336cedef18e Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Wed, 19 May 2021 12:27:24 -0600 Subject: [PATCH 14/33] Make SkipLayer work if there are any additional init or forward arguments --- captum/optim/models/_common.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/captum/optim/models/_common.py b/captum/optim/models/_common.py index d8976b4bc3..4eb94b1a9a 100644 --- a/captum/optim/models/_common.py +++ b/captum/optim/models/_common.py @@ -202,7 +202,10 @@ class SkipLayer(torch.nn.Module): This layer is made to take the place of nonlinear activation layers like ReLU. """ - def forward(self, x: torch.Tensor) -> torch.Tensor: + def __init__(self, *args, **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: return x From 1bcac931f784165ffdbd3da5a38d251c27427e61 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Thu, 20 May 2021 15:41:39 -0600 Subject: [PATCH 15/33] Minor correction to optimize's loss summarizer setup It now fits with the `Optional` type hint that it was given. --- captum/optim/_core/optimization.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/captum/optim/_core/optimization.py b/captum/optim/_core/optimization.py index 75d613a336..27c5bf3162 100644 --- a/captum/optim/_core/optimization.py +++ b/captum/optim/_core/optimization.py @@ -111,7 +111,7 @@ def optimize( self, stop_criteria: Optional[StopCriteria] = None, optimizer: Optional[optim.Optimizer] = None, - loss_summarize_fn: Optional[Callable] = default_loss_summarize, + loss_summarize_fn: Optional[Callable] = None, lr: float = 0.025, ) -> torch.Tensor: r"""Optimize input based on loss function and objectives. 
@@ -131,6 +131,7 @@ def optimize( stop_criteria = stop_criteria or n_steps(512) optimizer = optimizer or optim.Adam(self.parameters(), lr=lr) assert isinstance(optimizer, optim.Optimizer) + loss_summarize_fn = loss_summarize_fn or default_loss_summarize history = [] step = 0 From 811a269aafce3a9d8e84ea0720d1b97d83fcc6cc Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Fri, 21 May 2021 17:59:15 -0600 Subject: [PATCH 16/33] Fix _rand_select bug * Fix issue where the final value in a list was not selectable. * Fix error when lists have a size of 1. --- captum/optim/_param/image/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py index 899695594b..8d5f43ec1c 100644 --- a/captum/optim/_param/image/transforms.py +++ b/captum/optim/_param/image/transforms.py @@ -217,9 +217,9 @@ def _rand_select( transform_values: NumSeqOrTensorType, ) -> Union[int, float, torch.Tensor]: """ - Randomly return a value from the provided tuple or list + Randomly return a single value from the provided tuple, list, or tensor. """ - n = torch.randint(low=0, high=len(transform_values) - 1, size=[1]).item() + n = torch.randint(low=0, high=len(transform_values), size=[1]).item() return transform_values[n] From 1f2d42192c555885b964c072e71d3863da1f8165 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sat, 22 May 2021 09:24:09 -0600 Subject: [PATCH 17/33] Increase number of steps in optimization test * Tests showed that using only 5 iterations was no longer sufficient to ensure the final loss values were less than the first loss values. --- tests/optim/core/test_optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/optim/core/test_optimization.py b/tests/optim/core/test_optimization.py index 581cc194fc..7a8ff6aa92 100644 --- a/tests/optim/core/test_optimization.py +++ b/tests/optim/core/test_optimization.py @@ -17,7 +17,7 @@ def test_input_optimization(self) -> None: model = BasicModel_ConvNet_Optim() loss_fn = opt.loss.ChannelActivation(model.layer, 0) obj = opt.InputOptimization(model, loss_function=loss_fn) - n_steps = 5 + n_steps = 25 history = obj.optimize(opt.optimization.n_steps(n_steps, show_progress=False)) self.assertTrue(history[0] > history[-1]) self.assertTrue(len(history) == n_steps) From 8227b797c7c5d187015bf1776287c55877345460 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sat, 22 May 2021 14:21:36 -0600 Subject: [PATCH 18/33] Fix FFTImage support for images with odd width values --- captum/optim/_param/image/images.py | 4 +--- tests/optim/helpers/numpy_image.py | 15 --------------- tests/optim/param/test_images.py | 10 ++++++++++ 3 files changed, 11 insertions(+), 18 deletions(-) diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index e43c56ab01..7884b2b116 100755 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -138,9 +138,7 @@ def __init__( def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor: """Computes 2D spectrum frequencies.""" fy = self.torch_fftfreq(height)[:, None] - # on odd input dimensions we need to keep one additional frequency - wadd = 2 if width % 2 == 1 else 1 - fx = self.torch_fftfreq(width)[: width // 2 + wadd] + fx = self.torch_fftfreq(width)[: width // 2 + 1] return torch.sqrt((fx * fx) + (fy * fy)) def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]: diff --git a/tests/optim/helpers/numpy_image.py b/tests/optim/helpers/numpy_image.py index 
3a25e501b7..ea60005197 100644 --- a/tests/optim/helpers/numpy_image.py +++ b/tests/optim/helpers/numpy_image.py @@ -3,17 +3,6 @@ import numpy as np -def setup_batch(x: np.ndarray, batch: int = 1, dim: int = 3) -> np.ndarray: - assert batch > 0 - x = x[None, :] if x.ndim == dim and batch == 1 else x - x = ( - np.stack([np.copy(x) for b in range(batch)]) - if x.ndim == dim and batch > 1 - else x - ) - return x - - class FFTImage: """Parameterize an image using inverse real 2D FFT""" @@ -67,10 +56,6 @@ def rfft2d_freqs(height: int, width: int) -> np.ndarray: fx = np.fft.fftfreq(width)[: width // 2 + wadd] return np.sqrt((fx * fx) + (fy * fy)) - def set_image(self, correlated_image: np.ndarray) -> None: - coeffs = np.fft.rfftn(correlated_image, s=self.size).view("(2,)float") - self.fourier_coeffs = coeffs / self.spectrum_scale - def forward(self) -> np.ndarray: h, w = self.size scaled_spectrum = self.fourier_coeffs * self.spectrum_scale diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index 5ac71eea7e..1a2db55024 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -150,6 +150,16 @@ def test_fftimage_forward_init_randn_channels(self) -> None: self.assertEqual(fftimage_tensor.detach().numpy().shape, fftimage_array.shape) + def test_fftimage_forward_randn_init_width_odd(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping FFTImage test due to insufficient Torch version." + ) + fftimage = images.FFTImage(size=(512, 405)) + self.assertEqual(list(fftimage.spectrum_scale.shape), [1, 512, 203, 1]) + fftimage_tensor = fftimage().detach() + self.assertEqual(list(fftimage_tensor.shape), [1, 3, 512, 405]) + def test_fftimage_forward_init_chw(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( From f10fa86acf0f91f131057e62813b883fb85152f9 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sat, 22 May 2021 14:40:24 -0600 Subject: [PATCH 19/33] Fix test_rfft2d_freqs & add more SkipLayer tests --- tests/optim/models/test_models_common.py | 12 ++++++++++++ tests/optim/param/test_images.py | 8 +++++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/optim/models/test_models_common.py b/tests/optim/models/test_models_common.py index 08c2d0a7d3..f6418b8d6c 100644 --- a/tests/optim/models/test_models_common.py +++ b/tests/optim/models/test_models_common.py @@ -270,6 +270,18 @@ def test_skip_layer(self) -> None: output_tensor = layer(x) assertTensorAlmostEqual(self, x, output_tensor, 0) + def test_skip_layer_ignore_init_variables(self) -> None: + layer = model_utils.SkipLayer(0, inplace=True) + x = torch.randn(1, 3, 4, 4) + output_tensor = layer(x) + assertTensorAlmostEqual(self, x, output_tensor, 0) + + def test_skip_layer_ignore_forward_variables(self) -> None: + layer = model_utils.SkipLayer() + x = torch.randn(1, 3, 4, 4) + output_tensor = layer(x, 1, inverse=True) + assertTensorAlmostEqual(self, x, output_tensor, 0) + class TestSkipLayersFunction(BaseTest): def test_skip_layers(self) -> None: diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index 1a2db55024..2400e61159 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -98,9 +98,11 @@ def test_rfft2d_freqs(self) -> None: height = 2 width = 3 image = images.FFTImage((1, 1)) - assertArraysAlmostEqual( - image.rfft2d_freqs(height, width).numpy(), - numpy_image.FFTImage.rfft2d_freqs(height, width), + + assertTensorAlmostEqual( + self, + 
image.rfft2d_freqs(height, width), + torch.tensor([[0.0000, 0.3333], [0.5000, 0.6009]]), ) def test_fftimage_forward_randn_init(self) -> None: From 91575242d58b1fa82c40a4d8aafb9bfb6f65685b Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sun, 23 May 2021 09:08:11 -0600 Subject: [PATCH 20/33] Remove duplicate imports --- captum/optim/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/captum/optim/__init__.py b/captum/optim/__init__.py index 0f00b53645..dc16251393 100755 --- a/captum/optim/__init__.py +++ b/captum/optim/__init__.py @@ -12,7 +12,6 @@ show, weights_to_heatmap_2d, ) -from captum.optim._utils.reducer import ChannelReducer, posneg # noqa: F401 __all__ = [ "InputOptimization", From 635932163c09be8d0b03d4c8506bf5cf0c60bc4e Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sun, 23 May 2021 10:24:49 -0600 Subject: [PATCH 21/33] Make it possible to load RGBA images with ImageTensor --- captum/optim/_param/image/images.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index 7884b2b116..c5b1a82528 100755 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -35,13 +35,13 @@ def __new__( return super().__new__(cls, x, *args, **kwargs) @classmethod - def open(cls, path: str, scale: float = 255.0) -> "ImageTensor": + def open(cls, path: str, scale: float = 255.0, mode: str = "RGB") -> "ImageTensor": if path.startswith("https://") or path.startswith("http://"): response = requests.get(path, stream=True) img = Image.open(response.raw) else: img = Image.open(path) - img_np = np.array(img.convert("RGB")).astype(np.float32) + img_np = np.array(img.convert(mode)).astype(np.float32) return cls(img_np.transpose(2, 0, 1) / scale) def __repr__(self) -> str: From e4a1310f8931dde798b047b60679e1ae74c79ff8 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Mon, 24 May 2021 09:11:03 -0600 Subject: [PATCH 22/33] Remove unused line: 'h, w = self.size' --- captum/optim/_param/image/images.py | 1 - 1 file changed, 1 deletion(-) diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index c5b1a82528..e2ac50e5b7 100755 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -179,7 +179,6 @@ def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor: return torch_rfft, torch_irfft, torch_fftfreq def forward(self) -> torch.Tensor: - h, w = self.size scaled_spectrum = self.fourier_coeffs * self.spectrum_scale output = self.torch_irfft(scaled_spectrum) return output.refine_names("B", "C", "H", "W") From 9ca0f9bbb66f01cb872012693db9d5d39064a49a Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Tue, 25 May 2021 18:38:52 -0600 Subject: [PATCH 23/33] Change NumPy rfft2d_freqs to match PyTorch version --- tests/optim/helpers/numpy_image.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/optim/helpers/numpy_image.py b/tests/optim/helpers/numpy_image.py index ea60005197..303e7dfe73 100644 --- a/tests/optim/helpers/numpy_image.py +++ b/tests/optim/helpers/numpy_image.py @@ -51,13 +51,10 @@ def __init__( def rfft2d_freqs(height: int, width: int) -> np.ndarray: """Computes 2D spectrum frequencies.""" fy = np.fft.fftfreq(height)[:, None] - # on odd input dimensions we need to keep one additional frequency - wadd = 2 if width % 2 == 1 else 1 - fx = np.fft.fftfreq(width)[: width // 2 + wadd] + fx = np.fft.fftfreq(width)[: width // 2 + 1] return np.sqrt((fx * fx) + (fy * fy)) def forward(self) -> 
np.ndarray: - h, w = self.size scaled_spectrum = self.fourier_coeffs * self.spectrum_scale scaled_spectrum = scaled_spectrum.astype(complex) output = np.fft.irfftn(scaled_spectrum, s=self.size) From 681e3eec4de3c06de2202d05a5f50ec8b39137ac Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Wed, 26 May 2021 10:18:30 -0600 Subject: [PATCH 24/33] Fix PCA ChannelReducer test --- tests/optim/utils/test_reducer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/optim/utils/test_reducer.py b/tests/optim/utils/test_reducer.py index d01b61b9b0..2f509f6f8d 100644 --- a/tests/optim/utils/test_reducer.py +++ b/tests/optim/utils/test_reducer.py @@ -68,9 +68,8 @@ def test_channelreducer_pytorch_pca(self) -> None: ) test_input = torch.randn(1, 32, 224, 224).abs() - c_reducer = reducer.ChannelReducer( - n_components=3, reduction_alg="PCA", max_iter=100 - ) + c_reducer = reducer.ChannelReducer(n_components=3, reduction_alg="PCA") + test_output = c_reducer.fit_transform(test_input) self.assertEquals(test_output.size(0), 1) self.assertEquals(test_output.size(1), 3) From 8aaf1a2242b364ac5bbe528bec119cdceaa7817a Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Wed, 26 May 2021 11:30:39 -0600 Subject: [PATCH 25/33] Fix failing nodejs See: https://github.com/pytorch/captum/pull/675 for more details. --- scripts/install_via_conda.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/install_via_conda.sh b/scripts/install_via_conda.sh index 6eedd15f27..57fe72096d 100755 --- a/scripts/install_via_conda.sh +++ b/scripts/install_via_conda.sh @@ -40,7 +40,7 @@ conda install -y -c conda-forge black matplotlib pytest-cov sphinx-autodoc-typeh # install node/yarn for insights build conda install -y -c conda-forge yarn # nodejs should be last, otherwise other conda packages will downgrade node -conda update -y --no-channel-priority -c conda-forge nodejs +conda install -y --no-channel-priority -c conda-forge nodejs=14 # build insights and install captum From 27145c7a7f1d80f992a2b7c947f3d99ba9c16af6 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Wed, 26 May 2021 12:10:33 -0600 Subject: [PATCH 26/33] Resolve the ToRGB device issue with NaturalImage Resolve the `ToRGB` device issue as mentioned in: pytorch/captum#656 --- captum/optim/_param/image/images.py | 10 ++++++---- tests/optim/param/test_images.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index e2ac50e5b7..6bf5edf56e 100755 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -421,15 +421,17 @@ def __init__( init: Optional[torch.Tensor] = None, parameterization: ImageParameterization = FFTImage, squash_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, - decorrelation_module: Optional[nn.Module] = None, + decorrelation_module: Optional[nn.Module] = ToRGB(transform="klt"), decorrelate_init: bool = True, ) -> None: super().__init__() - self.decorrelate = decorrelation_module or ToRGB(transform="klt") + self.decorrelate = ( + decorrelation_module.cpu() if decorrelation_module is not None else None + ) if init is not None: + assert not init.is_cuda assert init.dim() == 3 or init.dim() == 4 - if decorrelate_init: - assert self.decorrelate is not None + if decorrelate_init and self.decorrelate is not None: init = ( init.refine_names("B", "C", "H", "W") if init.dim() == 4 diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index 
2400e61159..c979c08478 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -662,3 +662,14 @@ def test_natural_image_cuda(self) -> None: ) image_param = images.NaturalImage().cuda() self.assertTrue(image_param().is_cuda) + + def test_natural_image_decorrelation_module_none(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping NaturalImage test due to insufficient Torch version." + ) + image_param = images.NaturalImage( + init=torch.ones(1, 3, 4, 4), decorrelation_module=None + ) + image = image_param.forward().detach() + assertTensorAlmostEqual(self, image, torch.ones_like(image)) From 15a7db6b44ea276b319233a0ac85b03d68f1e0e4 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Wed, 26 May 2021 12:23:13 -0600 Subject: [PATCH 27/33] Fix no color decorrelation test --- tests/optim/param/test_images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/optim/param/test_images.py index c979c08478..525d6277aa 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -664,7 +664,7 @@ def test_natural_image_cuda(self) -> None: self.assertTrue(image_param().is_cuda) def test_natural_image_decorrelation_module_none(self) -> None: - if torch.__version__ <= "1.2.0": + if torch.__version__ <= "1.3.0": raise unittest.SkipTest( "Skipping NaturalImage test due to insufficient Torch version." ) From 02fc2d78920f0ae6c5b60eeeac6fe8c0a9c1b4c5 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Tue, 1 Jun 2021 09:12:07 -0600 Subject: [PATCH 28/33] Add ToDos and better SkipLayer docs --- captum/optim/_param/image/images.py | 7 ++++++- captum/optim/models/_common.py | 27 +++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/captum/optim/_param/image/images.py index 6bf5edf56e..dbeb8b289c 100755 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -209,6 +209,9 @@ def forward(self) -> torch.Tensor: class LaplacianImage(ImageParameterization): + """ + TODO: Fix division by 6 in setup_input when init is not None. + """ def __init__( self, size: Tuple[int, int] = None, @@ -386,7 +389,9 @@ def forward(self) -> torch.Tensor: class NaturalImage(ImageParameterization): - r"""Outputs an optimizable input image. + r"""TODO: Resolve device issues with default ToRGB instance, and init tensors. + + Outputs an optimizable input image. By convention, single images are CHW and float32s in [0,1]. The underlying parameterization can be decorrelated via a ToRGB transform. diff --git a/captum/optim/models/_common.py index 4eb94b1a9a..dd330ee639 100644 --- a/captum/optim/models/_common.py +++ b/captum/optim/models/_common.py @@ -199,13 +199,36 @@ def collect_activations( class SkipLayer(torch.nn.Module): """ - This layer is made to take the place of nonlinear activation layers like ReLU. + This layer is made to take the place of any layer that needs to be skipped over + during the forward pass. Use cases include removing nonlinear activation layers + like ReLU for circuits research. + + This layer works almost exactly the same way that nn.Identity does, except it also + ignores any additional arguments passed to the forward function. + + See nn.Identity for more details: + https://pytorch.org/docs/stable/generated/torch.nn.Identity.html + + Args: + args (Any): Any argument. Arguments will be safely ignored. + kwargs (Any): Any keyword argument.
Arguments will be safely ignored. """ def __init__(self, *args, **kwargs) -> None: super().__init__() - def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def forward( + self, x: Union[torch.Tensor, Tuple[torch.Tensor]], *args, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """ + Args: + x (torch.Tensor or tuple of torch.Tensor): The input tensor or tensors. + args (Any): Any argument. Arguments will be safely ignored. + kwargs (Any): Any keyword argument. Arguments will be safely ignored. + Returns: + x (torch.Tensor or tuple of torch.Tensor): The unmodified input tensor or + tensors. + """ return x From 7f078f936e2481068fdd142e038d77c6a740037f Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Tue, 1 Jun 2021 10:26:31 -0600 Subject: [PATCH 29/33] Fix lint and FFTImage init device --- captum/optim/_param/image/images.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/captum/optim/_param/image/images.py index dbeb8b289c..551990b8f3 100755 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -116,7 +116,6 @@ def __init__( ) scale = scale * ((self.size[0] * self.size[1]) ** (1 / 2)) spectrum_scale = scale[None, :, :, None] - self.register_buffer("spectrum_scale", spectrum_scale) if init is None: coeffs_shape = ( @@ -131,8 +130,10 @@ def __init__( ) # names=["C", "H_f", "W_f", "complex"] fourier_coeffs = random_coeffs / 50 else: + spectrum_scale = spectrum_scale.to(init.device) fourier_coeffs = self.torch_rfft(init) / spectrum_scale + self.register_buffer("spectrum_scale", spectrum_scale) self.fourier_coeffs = nn.Parameter(fourier_coeffs) @@ -390,7 +391,7 @@ def forward(self) -> torch.Tensor: class NaturalImage(ImageParameterization): r"""TODO: Resolve device issues with default ToRGB instance, and init tensors. - + Outputs an optimizable input image. By convention, single images are CHW and float32s in [0,1]. The underlying parameterization can be decorrelated via a ToRGB transform. From d9b4620062fc077268d0a6c116ba947536fd9c43 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Tue, 1 Jun 2021 10:49:52 -0600 Subject: [PATCH 30/33] Fix NaturalImage device issues --- captum/optim/_param/image/images.py | 12 +++++++----- captum/optim/_param/image/transforms.py | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/captum/optim/_param/image/images.py index 551990b8f3..0841d422f8 100755 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -213,6 +213,7 @@ class LaplacianImage(ImageParameterization): """ TODO: Fix division by 6 in setup_input when init is not None. """ + def __init__( self, size: Tuple[int, int] = None, @@ -390,7 +391,7 @@ def forward(self) -> torch.Tensor: class NaturalImage(ImageParameterization): - r"""TODO: Resolve device issues with default ToRGB instance, and init tensors. - - Outputs an optimizable input image. + r"""Outputs an optimizable input image. By convention, single images are CHW and float32s in [0,1]. The underlying parameterization can be decorrelated via a ToRGB transform.
@@ -431,11 +430,14 @@ def __init__( decorrelate_init: bool = True, ) -> None: super().__init__() + # Deep copy to avoid issue with creating class instance in the function + # signature self.decorrelate = ( - decorrelation_module.cpu() if decorrelation_module is not None else None + deepcopy(decorrelation_module) + if isinstance(decorrelation_module, ToRGB) + else decorrelation_module ) if init is not None: - assert not init.is_cuda assert init.dim() == 3 or init.dim() == 4 if decorrelate_init and self.decorrelate is not None: init = ( diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py index 8d5f43ec1c..ba3c146e40 100644 --- a/captum/optim/_param/image/transforms.py +++ b/captum/optim/_param/image/transforms.py @@ -107,9 +107,9 @@ def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor: h, w = x.size("H"), x.size("W") flat = x.flatten(("H", "W"), "spatials") if inverse: - correct = torch.inverse(self.transform) @ flat + correct = torch.inverse(self.transform.to(x.device)) @ flat else: - correct = self.transform @ flat + correct = self.transform.to(x.device) @ flat chw = correct.unflatten("spatials", (("H", h), ("W", w))) if x.dim() == 3: From 5aaece3aadf2718d9dcf83be63ad745225fd8b0b Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Tue, 1 Jun 2021 11:43:11 -0600 Subject: [PATCH 31/33] Improve NaturalImage fix --- captum/optim/_param/image/images.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index 0841d422f8..4a52099cee 100755 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -432,11 +432,10 @@ def __init__( super().__init__() # Deep copy to avoid issue with creating class instance in the function # signature - self.decorrelate = ( - deepcopy(decorrelation_module) - if isinstance(decorrelation_module, ToRGB) - else decorrelation_module - ) + if isinstance(decorrelation_module, ToRGB): + decorrelation_module = deepcopy(decorrelation_module) + self.decorrelate = decorrelation_module + if init is not None: assert init.dim() == 3 or init.dim() == 4 if decorrelate_init and self.decorrelate is not None: From c72adc558d9c791b96dac3b9bc7f0d7c200a19be Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Wed, 2 Jun 2021 14:10:59 -0600 Subject: [PATCH 32/33] Remove redundant NaturalImage device fix --- captum/optim/_param/image/images.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index 4a52099cee..b0852a512c 100755 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -430,12 +430,7 @@ def __init__( decorrelate_init: bool = True, ) -> None: super().__init__() - # Deep copy to avoid issue with creating class instance in the function - # signature - if isinstance(decorrelation_module, ToRGB): - decorrelation_module = deepcopy(decorrelation_module) self.decorrelate = decorrelation_module - if init is not None: assert init.dim() == 3 or init.dim() == 4 if decorrelate_init and self.decorrelate is not None: From 1eae61c6e0f9298514c8a2c33a5a62db07a7154f Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sun, 6 Jun 2021 19:11:07 -0600 Subject: [PATCH 33/33] Improve SkipLayer documentation --- captum/optim/models/_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/captum/optim/models/_common.py b/captum/optim/models/_common.py index dd330ee639..e9fba1ba27 100644 --- 
a/captum/optim/models/_common.py +++ b/captum/optim/models/_common.py @@ -204,7 +204,8 @@ class SkipLayer(torch.nn.Module): like ReLU for circuits research. This layer works almost exactly the same way that nn.Identity does, except it also - ignores any additional arguments passed to the forward function. + ignores any additional arguments passed to the forward function. Any layer replaced + by SkipLayer must have the same input and output shapes. See nn.Identity for more details: https://pytorch.org/docs/stable/generated/torch.nn.Identity.html
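Note on the SkipLayer patches (13, 14, 19, 28, and 33): after this series, SkipLayer accepts and ignores arbitrary constructor and forward arguments, so it can stand in for an activation such as `nn.ReLU(inplace=True)` without raising signature errors, as long as the replaced layer's input and output shapes match. A short usage sketch mirroring the tests added in patch 19 (the import path assumes the `captum.optim.models` export added in patch 13):

    import torch
    from captum.optim.models import SkipLayer

    # Constructor arguments are accepted and safely ignored.
    layer = SkipLayer(0, inplace=True)
    x = torch.randn(1, 3, 4, 4)
    # Extra forward arguments are ignored too; the input passes through unchanged.
    assert torch.equal(layer(x, 1, inverse=True), x)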