diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index ecd096ba6e..f7a03e21dd 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -62,7 +62,7 @@ def attribute( additional_forward_args: Any = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, perturbations_per_eval: int = 1, - **kwargs: Any + **kwargs: Any, ) -> TensorOrTupleOfTensorsGeneric: r""" Args: @@ -321,7 +321,7 @@ def attribute( baselines, feature_mask, perturbations_per_eval, - **kwargs + **kwargs, ): # modified_eval dimensions: 1D tensor with length # equal to #num_examples * #features in batch @@ -373,7 +373,7 @@ def _ablation_generator( baselines, input_mask, perturbations_per_eval, - **kwargs + **kwargs, ): """ This method is a generator which yields each perturbation to be evaluated @@ -458,7 +458,7 @@ def _ablation_generator( baseline, num_features_processed, num_features_processed + current_num_ablated_features, - **extra_args + **extra_args, ) # current_features[i] has dimension diff --git a/captum/metrics/_core/infidelity.py b/captum/metrics/_core/infidelity.py index b396e20543..2f06e497d8 100644 --- a/captum/metrics/_core/infidelity.py +++ b/captum/metrics/_core/infidelity.py @@ -68,7 +68,7 @@ def sub_infidelity_perturb_func_decorator(pertub_func: Callable) -> Callable: def default_perturb_func( inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None ): - r"""""" + r""" """ inputs_perturbed = ( pertub_func(inputs, baselines) if baselines is not None @@ -380,7 +380,7 @@ def _generate_perturbations( """ def call_perturb_func(): - r"""""" + r""" """ baselines_pert = None inputs_pert: Union[Tensor, Tuple[Tensor, ...]] if len(inputs_expanded) == 1: diff --git a/captum/optim/__init__.py b/captum/optim/__init__.py index 0f00b53645..dc16251393 100755 --- a/captum/optim/__init__.py +++ b/captum/optim/__init__.py @@ -12,7 +12,6 @@ show, weights_to_heatmap_2d, ) -from captum.optim._utils.reducer import ChannelReducer, posneg # noqa: F401 __all__ = [ "InputOptimization", diff --git a/captum/optim/_core/optimization.py b/captum/optim/_core/optimization.py index 63f30a016e..27c5bf3162 100644 --- a/captum/optim/_core/optimization.py +++ b/captum/optim/_core/optimization.py @@ -111,7 +111,7 @@ def optimize( self, stop_criteria: Optional[StopCriteria] = None, optimizer: Optional[optim.Optimizer] = None, - loss_summarize_fn: Optional[Callable] = default_loss_summarize, + loss_summarize_fn: Optional[Callable] = None, lr: float = 0.025, ) -> torch.Tensor: r"""Optimize input based on loss function and objectives. 
@@ -131,6 +131,7 @@ def optimize( stop_criteria = stop_criteria or n_steps(512) optimizer = optimizer or optim.Adam(self.parameters(), lr=lr) assert isinstance(optimizer, optim.Optimizer) + loss_summarize_fn = loss_summarize_fn or default_loss_summarize history = [] step = 0 @@ -138,7 +139,7 @@ def optimize( while stop_criteria(step, self, history, optimizer): optimizer.zero_grad() loss_value = loss_summarize_fn(self.loss()) - history.append(loss_value) + history.append(loss_value.clone().detach()) loss_value.backward() optimizer.step() step += 1 diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index 354fc1cb9f..b0852a512c 100755 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -35,13 +35,13 @@ def __new__( return super().__new__(cls, x, *args, **kwargs) @classmethod - def open(cls, path: str, scale: float = 255.0) -> "ImageTensor": + def open(cls, path: str, scale: float = 255.0, mode: str = "RGB") -> "ImageTensor": if path.startswith("https://") or path.startswith("http://"): response = requests.get(path, stream=True) img = Image.open(response.raw) else: img = Image.open(path) - img_np = np.array(img.convert("RGB")).astype(np.float32) + img_np = np.array(img.convert(mode)).astype(np.float32) return cls(img_np.transpose(2, 0, 1) / scale) def __repr__(self) -> str: @@ -116,7 +116,6 @@ def __init__( ) scale = scale * ((self.size[0] * self.size[1]) ** (1 / 2)) spectrum_scale = scale[None, :, :, None] - self.register_buffer("spectrum_scale", spectrum_scale) if init is None: coeffs_shape = ( @@ -131,16 +130,16 @@ def __init__( ) # names=["C", "H_f", "W_f", "complex"] fourier_coeffs = random_coeffs / 50 else: + spectrum_scale = spectrum_scale.to(init.device) fourier_coeffs = self.torch_rfft(init) / spectrum_scale + self.register_buffer("spectrum_scale", spectrum_scale) self.fourier_coeffs = nn.Parameter(fourier_coeffs) def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor: """Computes 2D spectrum frequencies.""" fy = self.torch_fftfreq(height)[:, None] - # on odd input dimensions we need to keep one additional frequency - wadd = 2 if width % 2 == 1 else 1 - fx = self.torch_fftfreq(width)[: width // 2 + wadd] + fx = self.torch_fftfreq(width)[: width // 2 + 1] return torch.sqrt((fx * fx) + (fy * fy)) def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]: @@ -181,7 +180,6 @@ def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor: return torch_rfft, torch_irfft, torch_fftfreq def forward(self) -> torch.Tensor: - h, w = self.size scaled_spectrum = self.fourier_coeffs * self.spectrum_scale output = self.torch_irfft(scaled_spectrum) return output.refine_names("B", "C", "H", "W") @@ -212,6 +210,10 @@ def forward(self) -> torch.Tensor: class LaplacianImage(ImageParameterization): + """ + TODO: Fix division by 6 in setup_input when init is not None.
+ """ + def __init__( self, size: Tuple[int, int] = None, @@ -418,7 +420,7 @@ class NaturalImage(ImageParameterization): def __init__( self, - size: Tuple[int, int] = [224, 224], + size: Tuple[int, int] = (224, 224), channels: int = 3, batch: int = 1, init: Optional[torch.Tensor] = None, @@ -431,8 +433,7 @@ def __init__( self.decorrelate = decorrelation_module if init is not None: assert init.dim() == 3 or init.dim() == 4 - if decorrelate_init: - assert self.decorrelate is not None + if decorrelate_init and self.decorrelate is not None: init = ( init.refine_names("B", "C", "H", "W") if init.dim() == 4 diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py index 049366e1c4..ba3c146e40 100644 --- a/captum/optim/_param/image/transforms.py +++ b/captum/optim/_param/image/transforms.py @@ -107,9 +107,9 @@ def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor: h, w = x.size("H"), x.size("W") flat = x.flatten(("H", "W"), "spatials") if inverse: - correct = torch.inverse(self.transform) @ flat + correct = torch.inverse(self.transform.to(x.device)) @ flat else: - correct = self.transform @ flat + correct = self.transform.to(x.device) @ flat chw = correct.unflatten("spatials", (("H", h), ("W", w))) if x.dim() == 3: @@ -217,9 +217,9 @@ def _rand_select( transform_values: NumSeqOrTensorType, ) -> Union[int, float, torch.Tensor]: """ - Randomly return a value from the provided tuple or list + Randomly return a single value from the provided tuple, list, or tensor. """ - n = torch.randint(low=0, high=len(transform_values) - 1, size=[1]).item() + n = torch.randint(low=0, high=len(transform_values), size=[1]).item() return transform_values[n] @@ -503,53 +503,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) -class AlphaChannelLoss(nn.Module): - """ - TODO: Fix AlphaChannelLoss - Transform for calculating alpha channel loss, without altering the input tensor. - Loss values are calculated in such a way that opaque and transparent regions of - the tensor are automatically balanced. - ​ - See: https://distill.pub/2018/differentiable-parameterizations/ - Mordvintsev, et al., "Differentiable Image Parameterizations", Distill, 2018. - ​ - Args: - scale (float, sequence): Tuple of rescaling values to randomly select from. - crop_size (int, sequence, int, optional): The desired cropped output size - for secondary alpha channel loss. - background (tensor, optional): An NCHW image tensor to be used as the - alpha channel's background. 
- """ - - def __init__( - self, - scale: NumSeqOrTensorType, - crop_size: Optional[Tuple[int, int]] = None, - background: Optional[torch.Tensor] = None, - ) -> None: - raise NotImplementedError # We are not ready for this - super().__init__() - self.random_scale = RandomScale(scale=scale) - self.crop_size = crop_size - self.random_crop = RandomCrop(crop_size) - self.blend_alpha = BlendAlpha(background=background) - self.loss = 0 - - def forward(self, x: torch.Tensor) -> torch.Tensor: - assert x.dim() == 4 # Should be of shape (batch, channel, height, width) - assert x.size(1) == 4 # Channel dim should be rgba - - x_shifted = torch.cat([self.blend_alpha(x.clone()), x.clone()[:, 3:]], 1) - - x_shifted = self.random_scale(x_shifted) - x_shifted_crop = self.random_crop(x_shifted) - - self.loss = (1.0 - x_shifted[:, 3:].mean()) + ( - (1.0 - x_shifted_crop[:, 3:].mean()) * 0.5 - ) - return x - - __all__ = [ "BlendAlpha", "IgnoreAlpha", diff --git a/captum/optim/models/__init__.py b/captum/optim/models/__init__.py index 635d1eb5b6..a970e68ec4 100755 --- a/captum/optim/models/__init__.py +++ b/captum/optim/models/__init__.py @@ -1,5 +1,6 @@ from ._common import ( # noqa: F401 RedirectedReluLayer, + SkipLayer, collect_activations, get_model_layers, replace_layers, @@ -10,6 +11,7 @@ __all__ = [ "RedirectedReluLayer", + "SkipLayer", "collect_activations", "get_model_layers", "replace_layers", diff --git a/captum/optim/models/_common.py b/captum/optim/models/_common.py index d8976b4bc3..e9fba1ba27 100644 --- a/captum/optim/models/_common.py +++ b/captum/optim/models/_common.py @@ -199,10 +199,37 @@ def collect_activations( class SkipLayer(torch.nn.Module): """ - This layer is made to take the place of nonlinear activation layers like ReLU. + This layer is made to take the place of any layer that needs to be skipped over + during the forward pass. Use cases include removing nonlinear activation layers + like ReLU for circuits research. + + This layer works almost exactly the same way that nn.Identity does, except it also + ignores any additional arguments passed to the forward function. Any layer replaced + by SkipLayer must have the same input and output shapes. + + See nn.Identity for more details: + https://pytorch.org/docs/stable/generated/torch.nn.Identity.html + + Args: + args (Any): Any argument. Arguments will be safely ignored. + kwargs (Any): Any keyword argument. Arguments will be safely ignored. """ - def forward(self, x: torch.Tensor) -> torch.Tensor: + def __init__(self, *args, **kwargs) -> None: + super().__init__() + + def forward( + self, x: Union[torch.Tensor, Tuple[torch.Tensor]], *args, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """ + Args: + x (torch.Tensor or tuple of torch.Tensor): The input tensor or tensors. + args (Any): Any argument. Arguments will be safely ignored. + kwargs (Any): Any keyword argument. Arguments will be safely ignored. + Returns: + x (torch.Tensor or tuple of torch.Tensor): The unmodified input tensor or + tensors.
+ """ return x diff --git a/captum/optim/models/_image/inception_v1.py b/captum/optim/models/_image/inception_v1.py index 65e73e32d9..102581c095 100644 --- a/captum/optim/models/_image/inception_v1.py +++ b/captum/optim/models/_image/inception_v1.py @@ -7,7 +7,7 @@ GS_SAVED_WEIGHTS_URL = ( "https://github.com/pytorch/captum/raw/" - + "optim-wip/captum/optim/_models/inception5h.pth" + + "optim-wip/captum/optim/models/_image/inception5h.pth" ) diff --git a/scripts/install_via_conda.sh b/scripts/install_via_conda.sh index b647293f47..57fe72096d 100755 --- a/scripts/install_via_conda.sh +++ b/scripts/install_via_conda.sh @@ -34,15 +34,14 @@ else fi # install other deps -conda install -y numpy sphinx pytest flake8 ipywidgets ipython -conda install -y -c conda-forge black matplotlib pytest-cov sphinx-autodoc-typehints mypy flask isort +conda install -y numpy sphinx pytest flake8 ipywidgets ipython scikit-learn +conda install -y -c conda-forge black matplotlib pytest-cov sphinx-autodoc-typehints mypy flask isort flask-compress # install node/yarn for insights build conda install -y -c conda-forge yarn # nodejs should be last, otherwise other conda packages will downgrade node -conda update -y --no-channel-priority -c conda-forge nodejs +conda install -y --no-channel-priority -c conda-forge nodejs=14 # build insights and install captum -# TODO: remove CI=false when we want React warnings treated as errors -CI=false BUILD_INSIGHTS=1 python setup.py develop +BUILD_INSIGHTS=1 python setup.py develop diff --git a/tests/optim/core/test_optimization.py b/tests/optim/core/test_optimization.py index 581cc194fc..7a8ff6aa92 100644 --- a/tests/optim/core/test_optimization.py +++ b/tests/optim/core/test_optimization.py @@ -17,7 +17,7 @@ def test_input_optimization(self) -> None: model = BasicModel_ConvNet_Optim() loss_fn = opt.loss.ChannelActivation(model.layer, 0) obj = opt.InputOptimization(model, loss_function=loss_fn) - n_steps = 5 + n_steps = 25 history = obj.optimize(opt.optimization.n_steps(n_steps, show_progress=False)) self.assertTrue(history[0] > history[-1]) self.assertTrue(len(history) == n_steps) diff --git a/tests/optim/helpers/numpy_image.py b/tests/optim/helpers/numpy_image.py index 3a25e501b7..303e7dfe73 100644 --- a/tests/optim/helpers/numpy_image.py +++ b/tests/optim/helpers/numpy_image.py @@ -3,17 +3,6 @@ import numpy as np -def setup_batch(x: np.ndarray, batch: int = 1, dim: int = 3) -> np.ndarray: - assert batch > 0 - x = x[None, :] if x.ndim == dim and batch == 1 else x - x = ( - np.stack([np.copy(x) for b in range(batch)]) - if x.ndim == dim and batch > 1 - else x - ) - return x - - class FFTImage: """Parameterize an image using inverse real 2D FFT""" @@ -62,17 +51,10 @@ def __init__( def rfft2d_freqs(height: int, width: int) -> np.ndarray: """Computes 2D spectrum frequencies.""" fy = np.fft.fftfreq(height)[:, None] - # on odd input dimensions we need to keep one additional frequency - wadd = 2 if width % 2 == 1 else 1 - fx = np.fft.fftfreq(width)[: width // 2 + wadd] + fx = np.fft.fftfreq(width)[: width // 2 + 1] return np.sqrt((fx * fx) + (fy * fy)) - def set_image(self, correlated_image: np.ndarray) -> None: - coeffs = np.fft.rfftn(correlated_image, s=self.size).view("(2,)float") - self.fourier_coeffs = coeffs / self.spectrum_scale - def forward(self) -> np.ndarray: - h, w = self.size scaled_spectrum = self.fourier_coeffs * self.spectrum_scale scaled_spectrum = scaled_spectrum.astype(complex) output = np.fft.irfftn(scaled_spectrum, s=self.size) diff --git 
a/tests/optim/models/test_models_common.py b/tests/optim/models/test_models_common.py index 08c2d0a7d3..f6418b8d6c 100644 --- a/tests/optim/models/test_models_common.py +++ b/tests/optim/models/test_models_common.py @@ -270,6 +270,18 @@ def test_skip_layer(self) -> None: output_tensor = layer(x) assertTensorAlmostEqual(self, x, output_tensor, 0) + def test_skip_layer_ignore_init_variables(self) -> None: + layer = model_utils.SkipLayer(0, inplace=True) + x = torch.randn(1, 3, 4, 4) + output_tensor = layer(x) + assertTensorAlmostEqual(self, x, output_tensor, 0) + + def test_skip_layer_ignore_forward_variables(self) -> None: + layer = model_utils.SkipLayer() + x = torch.randn(1, 3, 4, 4) + output_tensor = layer(x, 1, inverse=True) + assertTensorAlmostEqual(self, x, output_tensor, 0) + class TestSkipLayersFunction(BaseTest): def test_skip_layers(self) -> None: diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index fb354e82dd..525d6277aa 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -19,6 +19,66 @@ class TestImageTensor(BaseTest): def test_repr(self) -> None: self.assertEqual(str(images.ImageTensor()), "ImageTensor([])") + def test_new(self) -> None: + x = torch.ones(5) + test_tensor = images.ImageTensor(x) + self.assertTrue(torch.is_tensor(test_tensor)) + self.assertEqual(x.shape, test_tensor.shape) + + def test_new_numpy(self) -> None: + x = torch.ones(5).numpy() + test_tensor = images.ImageTensor(x) + self.assertTrue(torch.is_tensor(test_tensor)) + self.assertEqual(x.shape, test_tensor.shape) + + def test_new_list(self) -> None: + x = torch.ones(5) + test_tensor = images.ImageTensor(x.tolist()) + self.assertTrue(torch.is_tensor(test_tensor)) + self.assertEqual(x.shape, test_tensor.shape) + + def test_torch_function(self) -> None: + x = torch.ones(5) + image_tensor = images.ImageTensor(x) + image_tensor = (image_tensor * 1) * torch.ones(5) + self.assertEqual(image_tensor.sum().item(), torch.ones(5).sum().item()) + + def test_load_image_from_url(self) -> None: + try: + from PIL import Image # noqa: F401 + + except (ImportError, AssertionError): + raise unittest.SkipTest( + "Module Pillow / PIL not found, skipping ImageTensor load from url" + + " test" + ) + img_url = ( + "https://github.com/pytorch/captum" + + "/raw/master/website/static/img/captum_logo.png" + ) + new_tensor = images.ImageTensor().open(img_url) + self.assertTrue(torch.is_tensor(new_tensor)) + self.assertEqual(list(new_tensor.shape), [3, 54, 208]) + + def test_export_and_open_local_image(self) -> None: + try: + from PIL import Image # noqa: F401 + + except (ImportError, AssertionError): + raise unittest.SkipTest( + "Module Pillow / PIL not found, skipping ImageTensor export and save" + + " local image test" + ) + x = torch.ones(1, 3, 5, 5) + image_tensor = images.ImageTensor(x) + + filename = "image_tensor.jpg" + image_tensor.export(filename) + new_tensor = images.ImageTensor().open(filename) + + self.assertTrue(torch.is_tensor(new_tensor)) + assertTensorAlmostEqual(self, image_tensor, new_tensor) + def test_natural_image_cuda(self) -> None: if not torch.cuda.is_available(): raise unittest.SkipTest( @@ -38,9 +98,11 @@ def test_rfft2d_freqs(self) -> None: height = 2 width = 3 image = images.FFTImage((1, 1)) - assertArraysAlmostEqual( - image.rfft2d_freqs(height, width).numpy(), - numpy_image.FFTImage.rfft2d_freqs(height, width), + + assertTensorAlmostEqual( + self, + image.rfft2d_freqs(height, width), + torch.tensor([[0.0000, 0.3333], [0.5000, 0.6009]]), ) 
def test_fftimage_forward_randn_init(self) -> None: @@ -90,6 +152,16 @@ def test_fftimage_forward_init_randn_channels(self) -> None: self.assertEqual(fftimage_tensor.detach().numpy().shape, fftimage_array.shape) + def test_fftimage_forward_randn_init_width_odd(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping FFTImage test due to insufficient Torch version." + ) + fftimage = images.FFTImage(size=(512, 405)) + self.assertEqual(list(fftimage.spectrum_scale.shape), [1, 512, 203, 1]) + fftimage_tensor = fftimage().detach() + self.assertEqual(list(fftimage_tensor.shape), [1, 3, 512, 405]) + def test_fftimage_forward_init_chw(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( @@ -590,3 +662,14 @@ def test_natural_image_cuda(self) -> None: ) image_param = images.NaturalImage().cuda() self.assertTrue(image_param().is_cuda) + + def test_natural_image_decorrelation_module_none(self) -> None: + if torch.__version__ <= "1.3.0": + raise unittest.SkipTest( + "Skipping NaturalImage test due to insufficient Torch version." + ) + image_param = images.NaturalImage( + init=torch.ones(1, 3, 4, 4), decorrelation_module=None + ) + image = image_param.forward().detach() + assertTensorAlmostEqual(self, image, torch.ones_like(image)) diff --git a/tests/optim/param/test_transforms.py b/tests/optim/param/test_transforms.py index 7dff299c73..ade3ba37e7 100644 --- a/tests/optim/param/test_transforms.py +++ b/tests/optim/param/test_transforms.py @@ -650,22 +650,3 @@ def test_random_crop(self) -> None: x_out = crop_transform(x) self.assertEqual(list(x_out.shape), [1, 4, 160, 160]) - - -class TestAlphaChannelLoss(BaseTest): - def test_alpha_channel_loss_forward(self) -> None: - raise unittest.SkipTest( - "Skipping AlphaChannelLoss test until function is ready." - ) - # crop_size = [160, 160] - # scale = [0.6, 0.7, 0.8, 0.9, 1.0, 1.1] - - # alpha_loss_transform = transforms.AlphaChannelLoss( - # scale=scale, crop_size=crop_size - # ) - # x = torch.randn(1, 4, 224, 224) - - # x_out = alpha_loss_transform(x) - - # assertTensorAlmostEqual(self, x_out, x, 0) - # self.assertNotEqual(alpha_loss_transforms.loss, 0) diff --git a/tests/optim/utils/test_reducer.py b/tests/optim/utils/test_reducer.py index d01b61b9b0..2f509f6f8d 100644 --- a/tests/optim/utils/test_reducer.py +++ b/tests/optim/utils/test_reducer.py @@ -68,9 +68,8 @@ def test_channelreducer_pytorch_pca(self) -> None: ) test_input = torch.randn(1, 32, 224, 224).abs() - c_reducer = reducer.ChannelReducer( - n_components=3, reduction_alg="PCA", max_iter=100 - ) + c_reducer = reducer.ChannelReducer(n_components=3, reduction_alg="PCA") + test_output = c_reducer.fit_transform(test_input) self.assertEquals(test_output.size(0), 1) self.assertEquals(test_output.size(1), 3)
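Note on the InputOptimization.optimize() changes above: resolving loss_summarize_fn to default_loss_summarize inside the method body keeps the callable out of the public signature, and recording loss_value.clone().detach() stops the history list from keeping every step's autograd graph alive. The minimal sketch below reproduces the same pattern outside of captum; make_loss is a hypothetical stand-in for loss_summarize_fn(self.loss()).

import torch

def make_loss(param: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for loss_summarize_fn(self.loss()).
    return (param ** 2).sum()

param = torch.nn.Parameter(torch.randn(4))
optimizer = torch.optim.Adam([param], lr=0.025)
history = []

for step in range(3):
    optimizer.zero_grad()
    loss_value = make_loss(param)
    # clone().detach() stores only the value; appending loss_value itself would
    # keep the graph that produced it reachable for as long as history lives.
    history.append(loss_value.clone().detach())
    loss_value.backward()
    optimizer.step()

print([v.item() for v in history])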
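Note on dropping the wadd special case in rfft2d_freqs: a one-sided real FFT has width // 2 + 1 frequency bins for both even and odd widths, so slicing torch_fftfreq(width) at width // 2 + 1 keeps spectrum_scale aligned with the rfft output (405 gives 203 bins, the shape checked by the new odd-width test). The NumPy check below is illustrative only and mirrors the helper in tests/optim/helpers/numpy_image.py.

import numpy as np

for width in (404, 405):
    # np.fft.rfftfreq returns the one-sided frequency bins directly.
    n_bins = np.fft.rfftfreq(width).shape[0]
    assert n_bins == width // 2 + 1
    # Slicing the two-sided fftfreq the way rfft2d_freqs now does gives the same
    # count; for even widths the last bin differs only in sign, which does not
    # matter because rfft2d_freqs squares the frequencies.
    assert np.fft.fftfreq(width)[: width // 2 + 1].shape[0] == n_bins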
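Note on the expanded SkipLayer: the sketch below mirrors the two new tests and shows how the *args/**kwargs handling lets SkipLayer be dropped in wherever another layer used to sit; both constructor and forward arguments are accepted and ignored. The closing comment about replace_layers assumes its existing (model, old_layer_type, new_layer_type) calling convention.

import torch
from captum.optim.models import SkipLayer

# Constructor arguments are ignored, so SkipLayer can be built with whatever
# arguments the replaced layer expected, e.g. the inplace flag of nn.ReLU.
layer = SkipLayer(0, inplace=True)

x = torch.randn(1, 3, 4, 4)

# Extra forward arguments are ignored too; the input passes through unchanged.
assert torch.equal(layer(x, 1, inverse=True), x)

# For circuits research the same idea applies model-wide, e.g. (assumed signature):
# replace_layers(model, torch.nn.ReLU, SkipLayer)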