
Conversation

ProGamerGov
Contributor

@ProGamerGov ProGamerGov commented Apr 26, 2021

  • Quick fix for the model file path, as it's breaking all the notebooks & tests.

  • It also looks like black was updated and is now reporting an issue with captum/attr/_core/feature_ablation.py & captum/metrics/_core/infidelity.py, so I've resolved that as well. The files in the master branch may be a bit different, but we need to at least keep the tests passing for the other areas of Captum until we merge into master.

  • Added new ImageTensor tests.

  • Added detach() to InputOptimization to fix out-of-memory crashes.

  • Made sure that SkipLayer can handle any additional arguments to its init and forward functions. This makes the layer more useful than a plain copy of nn.Identity (see the sketch after this list).

  • Fixed various bugs.
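A rough sketch of that SkipLayer behavior (the actual Captum implementation may differ): any extra constructor or forward arguments are accepted and ignored, and the input is returned unchanged.

    import torch
    import torch.nn as nn
    from typing import Tuple, Union


    class SkipLayer(nn.Module):
        # Extra init arguments (e.g. inplace=True) are accepted and discarded, so
        # SkipLayer can stand in for layers constructed like nn.ReLU(inplace=True).
        def __init__(self, *args, **kwargs) -> None:
            super().__init__()

        # Extra forward arguments (e.g. an inverse flag) are also ignored.
        def forward(
            self, x: Union[torch.Tensor, Tuple[torch.Tensor, ...]], *args, **kwargs
        ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
            return x


    layer = SkipLayer(inplace=True)
    assert torch.equal(layer(torch.ones(2), inverse=True), torch.ones(2))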

@NarineK
Contributor

NarineK commented Apr 29, 2021

Thank you for the fixes, @ProGamerGov! It doesn't seem to fail on master. Are those Python files different on the optim-wip branch?

@ProGamerGov
Contributor Author

ProGamerGov commented Apr 29, 2021

@NarineK I think the files are slightly different, but @vivekmig also changed the same files in the most recent PR for the master branch: 5cf38cb. When we merge with the master branch we can update everything, but for the moment updating the files will ensure that the tests keep passing.

@ProGamerGov
Contributor Author

ProGamerGov commented Apr 29, 2021

Looks like the lint_test_py36_pip and lint_test_py37_conda tests are now failing with this error:

    @staticmethod
    def __new__(
        cls: Type["ImageTensor"],
        x: Union[List, np.ndarray, torch.Tensor] = [],
        *args,
        **kwargs,
    ) -> torch.Tensor:
        if isinstance(x, torch.Tensor) and x.is_cuda:
            x.show = MethodType(cls.show, x)
            x.export = MethodType(cls.export, x)
            return x
        else:
>           return super().__new__(cls, x, *args, **kwargs)
E           TypeError: object.__new__(ImageTensor) is not safe, use Tensor.__new__()

I think the error could be a Python bug or something similar, as I hadn't seen it until today.

@ProGamerGov
Contributor Author

I found the source of the TypeError when subclassing torch.Tensor! It's a bug in the nightly PyTorch builds. The lint_test_py36_pip test uses torch-1.9.0.dev20210501+cpu and the lint_test_py37_conda test uses torch==1.9.0.dev20210501, and thus they both fail on tests involving ImageTensor.

I made an issue post on the PyTorch repo here: pytorch/pytorch#57421

@ProGamerGov ProGamerGov changed the title Optim-wip: Fix the Inception5h model's download link Optim-wip: Fix failing tests & model download link May 2, 2021
* Remove `ImageTensor` test skips as the `torch.Tensor`'s `__new__` function has been fixed.
* Add tests for `ImageTensor` functions.
* Removed old `AlphaChannelLoss` code.
@ProGamerGov
Contributor Author

The issue with ImageTensor's __new__ function is now resolved in the latest version of the PyTorch nightly build!

@ProGamerGov
Contributor Author

ProGamerGov commented May 11, 2021

I have discovered a new issue with the code:

Running these two lines of code:

image_param = opt.images.NaturalImage(init=torch.ones(3, 1, 1)).cuda()
image_param = opt.images.NaturalImage(init=torch.ones(3, 1, 1)).cuda()

Results in this error:

RuntimeError                              Traceback (most recent call last)

<ipython-input-3-f944a33c48db> in <module>()
      1 image_param = opt.images.NaturalImage(init=torch.ones(3, 1, 1)).cuda()
----> 2 image_param = opt.images.NaturalImage(init=torch.ones(3, 1, 1)).cuda()

2 frames

/content/captum/captum/optim/_param/image/images.py in __init__(self, size, channels, batch, init, parameterization, squash_func, decorrelation_module, decorrelate_init)
    439                     else init.refine_names("C", "H", "W")
    440                 )
--> 441                 init = self.decorrelate(init, inverse=True).rename(None)
    442             if squash_func is None:
    443 

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

/content/captum/captum/optim/_param/image/transforms.py in forward(self, x, inverse)
    108         flat = x.flatten(("H", "W"), "spatials")
    109         if inverse:
--> 110             correct = torch.inverse(self.transform) @ flat
    111         else:
    112             correct = self.transform @ flat

RuntimeError: Tensor for argument #3 'mat2' is on CPU, but expected it to be on GPU (while checking arguments for addmm)

If I add a line to set the init variable to the device that the color decorrelation transform is on, then I get the following error:

RuntimeError                              Traceback (most recent call last)

<ipython-input-3-f944a33c48db> in <module>()
      1 image_param = opt.images.NaturalImage(init=torch.ones(3, 1, 1)).cuda()
----> 2 image_param = opt.images.NaturalImage(init=torch.ones(3, 1, 1)).cuda()

1 frames

/content/captum/captum/optim/_param/image/images.py in __init__(self, size, channels, batch, init, parameterization, squash_func, decorrelation_module, decorrelate_init)
    452         self.squash_func = squash_func
    453         self.parameterization = parameterization(
--> 454             size=size, channels=channels, batch=batch, init=init
    455         )
    456 

/content/captum/captum/optim/_param/image/images.py in __init__(self, size, channels, batch, init)
    132             fourier_coeffs = random_coeffs / 50
    133         else:
--> 134             fourier_coeffs = self.torch_rfft(init) / spectrum_scale
    135 
    136         self.fourier_coeffs = nn.Parameter(fourier_coeffs)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

We don't have to solve this issue in this PR though, as it doesn't affect any of the current tutorials.

Edit:

I have resolved the issue! Though my fix makes it harder to disable color decorrelation.

* Fix `NaturalImage` device bug.

* Set `decorrelate_init` default to `False`.

* Fix `NaturalImage` size type.
@ProGamerGov
Contributor Author

ProGamerGov commented May 16, 2021

So, currently NaturalImage is setup like this:

class NaturalImage(ImageParameterization):
    def __init__(
        self,
        size: Tuple[int, int] = (224, 224),
        channels: int = 3,
        batch: int = 1,
        init: Optional[torch.Tensor] = None,
        parameterization: ImageParameterization = FFTImage,
        squash_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        decorrelation_module: Optional[nn.Module] = None,
        decorrelate_init: bool = True,
    ) -> None:
        super().__init__()
        self.decorrelate = decorrelation_module or ToRGB(transform="klt")
        if init is not None:
            assert init.dim() == 3 or init.dim() == 4
            if decorrelate_init:
                assert self.decorrelate is not None
                init = (
                    init.refine_names("B", "C", "H", "W")
                    if init.dim() == 4
                    else init.refine_names("C", "H", "W")
                )
                init = self.decorrelate(init, inverse=True).rename(None)
            if squash_func is None:

                def squash_func(x: torch.Tensor) -> torch.Tensor:
                    return x.clamp(0, 1)

        else:
            if squash_func is None:

                squash_func = torch.sigmoid

        self.squash_func = squash_func
        self.parameterization = parameterization(
            size=size, channels=channels, batch=batch, init=init
        )

    def forward(self) -> torch.Tensor:
        image = self.parameterization()
        if self.decorrelate is not None:
            image = self.decorrelate(image)
        image = image.rename(None)  # TODO: the world is not yet ready
        return ImageTensor(self.squash_func(image))

If you want to disable color decorrelation / recorrelation, then you either need to set NaturalImage's decorrelate variable to None, or you need to set decorrelation_module to an empty pass-through module (both options are sketched below). I think that there may be a more elegant solution to turning the color decorrelation on and off, but I haven't figured anything out yet.

We can resolve this specific issue in a future PR.
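For reference, the two current workarounds look roughly like this (a rough sketch; opt is assumed to be the captum.optim alias used in the notebooks, and opt.models.SkipLayer is the pass-through layer mentioned later in this thread):

    # Option 1: null out the attribute after construction; forward() skips
    # decorrelation when self.decorrelate is None.
    image_param = opt.images.NaturalImage()
    image_param.decorrelate = None

    # Option 2: pass a pass-through module as the decorrelation module.
    image_param = opt.images.NaturalImage(decorrelation_module=opt.models.SkipLayer())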

It now fits with the `Optional` type hint that it was given.
* Fix issue where the final value in a list was not selectable.
* Fix error when lists have a size of 1.
@ProGamerGov
Contributor Author

ProGamerGov commented May 22, 2021

Well I have no idea what happened here:

self = <tests.optim.core.test_optimization.TestInputOptimization testMethod=test_input_optimization>

    def test_input_optimization(self) -> None:
        if torch.__version__ <= "1.2.0":
            raise unittest.SkipTest(
                "Skipping InputOptimization test due to insufficient Torch version."
            )
        model = BasicModel_ConvNet_Optim()
        loss_fn = opt.loss.ChannelActivation(model.layer, 0)
        obj = opt.InputOptimization(model, loss_function=loss_fn)
        n_steps = 5
        history = obj.optimize(opt.optimization.n_steps(n_steps, show_progress=False))
>       self.assertTrue(history[0] > history[-1])
E       AssertionError: ImageTensor(False) is not true

My fix for _rand_select is completely unrelated, so I don't understand why this error showed up.

Edit: I can't reproduce the error at all, so maybe it's an issue with CircleCI?

Second Edit:

I think the test was failing because the previous parameters it used were tuned for the old _rand_select, which never selected the final list value. Specifically, the default random scale list of (1, 0.975, 1.025, 0.95, 1.05) would never have resulted in 1.05 being chosen until I fixed _rand_select (see the sketch below).

I've fixed the issue now!
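For illustration, the kind of off-by-one that produces this behavior looks like the following (a hedged sketch; the actual _rand_select implementation may differ):

    import torch

    def rand_select_buggy(values):
        # randint's upper bound is exclusive, so len(values) - 1 makes the final
        # entry unreachable, and it errors out for single-item lists.
        return values[int(torch.randint(0, len(values) - 1, (1,)))]

    def rand_select_fixed(values):
        # Every entry, including the final one, can now be selected.
        return values[int(torch.randint(0, len(values), (1,)))]

    scale = (1, 0.975, 1.025, 0.95, 1.05)
    print(rand_select_fixed(scale))  # 1.05 is now reachable; the buggy version never returns it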

* Tests showed that using only 5 iterations was no longer sufficient to ensure the final loss values were less than the first loss values.
@ProGamerGov
Contributor Author

ProGamerGov commented May 24, 2021

@NarineK The lines I altered in feature_ablation.py and infidelity.py are now the same as in the master branch, so there won't be any merging conflicts caused by those changes.

@ProGamerGov
Contributor Author

ProGamerGov commented May 24, 2021

I looked into the FFTImage size thing, and there is a clear deviation between Ludwig's first Captum iteration of FFTImage and the Lucid code.

The Captum code creates the image with the shape given by the size variable, and then this same shape is used to create the scale / frequency tensor:

        coeffs_shape = (channels, size[0], size[1] // 2 + 1, 2)
        random_coeffs = torch.randn(
            coeffs_shape
        )  # names=["C", "H_f", "W_f", "complex"]
        self.fourier_coeffs = nn.Parameter(random_coeffs / 50)

        frequencies = FFTImage.rfft2d_freqs(*size)
  • The second-to-last dimension in coeffs_shape is always size[1] // 2 + 1 and never size[1] // 2 + 2, so making rfft2d_freqs always use // 2 + 1 prevents a size mismatch.

https://github.com/pytorch/captum/pull/412/files#diff-d9ef468be9704729ff7c3bd65ad5b115e206b5f418039e18cb333e123428dde5R160-R166

While Lucid derives the shape of the image from the shape of the frequency tensor (which the if statement changes when the width value is odd):

    batch, h, w, ch = shape
    freqs = rfft2d_freqs(h, w)
    init_val_size = (2, batch, ch) + freqs.shape

https://github.com/tensorflow/lucid/blob/master/lucid/optvis/param/spatial.py#L66-L67

Therefore, I think the special behavior for widths that have an odd value is required for the Lucid way of doing things, but not for the new Captum FFTImage. I think that Ludwig may have just forgotten to update rfft2d_freqs to match his changes in the init function.

@NarineK
Contributor

NarineK commented May 25, 2021

#3 'mat2' is

Interesting, according to the code snippet it looks like it worked the first time calling opt.images.NaturalImage(init=torch.ones(3, 1, 1)).cuda() but failed the second time?

The issue was that creating the ToRGB instance in the __init__ function's signature creates a single instance that gets reused as the default for all future constructor calls. It's apparently a well-known trap when using default parameters:

# Example code
import torch
import torch.nn as nn

class Test1(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("test_tensor", torch.ones(5))

    def forward(self, x):
        return x * self.test_tensor

class Test2(nn.Module):
    def __init__(self, x, test1 = Test1()):
        super().__init__()
        self.test1 = test1
        x = self.test1(x)

    def forward(self, x):
        return self.test1(x)

t = Test2(torch.ones(5)).cuda()
t = Test2(torch.ones(5)).cuda() # Fails
t = Test2(torch.ones(5), Test1()).cuda()
t = Test2(torch.ones(5), Test1()).cuda() # Works
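The usual way to avoid this trap is to default the argument to None and construct the module inside __init__ (a sketch reusing Test1 and the imports from the snippet above):

    class Test2Fixed(nn.Module):
        def __init__(self, x, test1=None):
            super().__init__()
            # A fresh Test1 (with fresh buffers) per instance, instead of one
            # shared default instance that an earlier .cuda() call already moved.
            self.test1 = test1 if test1 is not None else Test1()
            x = self.test1(x)

        def forward(self, x):
            return self.test1(x)

    t = Test2Fixed(torch.ones(5)).cuda()
    t = Test2Fixed(torch.ones(5)).cuda()  # Works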

code that Lucid uses

According to the Lucid implementation, it looks like we need to add 2 for odd cases and later remove that one additional pixel. What was the original problem that you saw with the line wadd = 2 if width % 2 == 1 else 1? What was failing?

The original issue was that fourier_coeffs and spectrum_scale would have a size mismatch if the width given to FFTImage, either through the size argument or the init tensor, was an odd number.

This code here would result in the following error:

image = opt.images.FFTImage((512, 405))
out = image()
RuntimeError                              Traceback (most recent call last)

<ipython-input-3-358d0813d735> in <module>()
      1 image = opt.images.FFTImage((512, 405))
----> 2 out = image()

1 frames

/content/captum/captum/optim/_param/image/images.py in forward(self)
    183     def forward(self) -> torch.Tensor:
    184         h, w = self.size
--> 185         scaled_spectrum = self.fourier_coeffs * self.spectrum_scale
    186         output = self.torch_irfft(scaled_spectrum)
    187         return output.refine_names("B", "C", "H", "W")

RuntimeError: The size of tensor a (203) must match the size of tensor b (204) at non-singleton dimension 3

And the same issue occurred when the init tensor size had an odd width:

image = opt.images.FFTImage(init=torch.ones(1, 3, 512, 405))
out = image()
RuntimeError                              Traceback (most recent call last)

<ipython-input-5-48fc824a1f55> in <module>()
----> 1 image = opt.images.FFTImage(init=torch.ones(1, 3, 512, 405))
      2 out = image()

/content/captum/captum/optim/_param/image/images.py in __init__(self, size, channels, batch, init)
    132             fourier_coeffs = random_coeffs / 50
    133         else:
--> 134             fourier_coeffs = self.torch_rfft(init) / spectrum_scale
    135 
    136         self.fourier_coeffs = nn.Parameter(fourier_coeffs)

RuntimeError: The size of tensor a (203) must match the size of tensor b (204) at non-singleton dimension 3

The Lucid code also talks about slicing the image tensor to resolve the issue, but in our case spectrum_scale is the bigger tensor and I'm not sure that we should be slicing it.
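For concreteness, the 203 vs. 204 mismatch comes from the two sizing conventions (a sketch; the second half assumes a Lucid-style rfft2d_freqs that keeps w // 2 + 2 frequency columns for odd widths):

    import torch

    h, w = 512, 405

    # torch.fft.rfftn keeps w // 2 + 1 frequency columns: 405 // 2 + 1 = 203.
    coeffs = torch.view_as_real(torch.fft.rfftn(torch.randn(1, 3, h, w), s=(h, w)))
    print(coeffs.shape)  # torch.Size([1, 3, 512, 203, 2])

    # A Lucid-style rfft2d_freqs keeps w // 2 + 2 columns when w is odd: 204.
    fy = torch.fft.fftfreq(h)[:, None]
    fx = (
        torch.fft.fftfreq(w)[: w // 2 + 2]
        if w % 2 == 1
        else torch.fft.fftfreq(w)[: w // 2 + 1]
    )
    print(torch.sqrt(fx * fx + fy * fy).shape)  # torch.Size([512, 204])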

Thank you for the explanation, @ProGamerGov! Is there a specific reason why we chose to use a buffer instead of an instance field? If we use it as an instance field it shouldn't cause any problems. I meant in ToRGB as well.

# Example code
import torch
import torch.nn as nn

class Test1(nn.Module):
    def __init__(self):
        super().__init__()
        self.test_tensor = torch.ones(5)
        print(self._buffers)

    def forward(self, x):
        return x * self.test_tensor

class Test2(nn.Module):
    def __init__(self, x, test1 = Test1()):
        super().__init__()
        self.test1 = test1
        x = self.test1(x)

    def forward(self, x):
        return self.test1(x)

t = Test2(torch.ones(5)).cuda()
t = Test2(torch.ones(5)).cuda() # Works

@NarineK
Contributor

NarineK commented May 25, 2021

@NarineK The lines I altered in feature_ablation.py and infidelity.py are now the same as in the master branch, so there won't be any merging conflicts caused by those changes.

If they are exactly the same then there won't be merge conflicts, but we will have the same changes in different commits. I think it is fine for now since those are minor changes, but it is better to sync up with master for larger changes.

@ProGamerGov
Contributor Author

ProGamerGov commented May 25, 2021

@NarineK Ludwig's original Captum code had both FFTImage and ToRGB using buffers. I'm not sure why he chose to do it that way, but I can change it if you think that they should be set as self variables?

If they are exactly the same then there won't be merge conflicts, but we will have the same changes in different commits. I think it is fine for now since those are minor changes, but it is better to sync up with master for larger changes.

Yeah, we can sync up with the master branch in a separate PR. I was just trying to avoid it in this PR so that the commit history wasn't filled with all the master commits, like what happened with SK's PR.

@ProGamerGov
Contributor Author

ProGamerGov commented May 25, 2021

Oh, I think I understand why buffers were used now. They make it so that when a NaturalImage instance is placed onto the GPU, the buffers are also placed on the GPU.

This code works when spectrum_scale (FFTImage) and transform (ToRGB) are buffers, but results in a device error when I just define them as self variables:

obj = opt.InputOptimization(
    model,
    opt.loss.ChannelActivation(model.mixed4a, 476),
    input_param=opt.images.NaturalImage((224,224)).to(device),
)
obj.optimize()

obj.input_param().show()
/content/captum/captum/optim/_param/image/images.py in forward(self)
    180 
    181     def forward(self) -> torch.Tensor:
--> 182         scaled_spectrum = self.fourier_coeffs * self.spectrum_scale
    183         output = self.torch_irfft(scaled_spectrum)
    184         return output.refine_names("B", "C", "H", "W")

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I guess we could override .to() & .cuda() as a possible solution? Or we could override _apply so that both .to() and .cuda() work. Though I think that buffers are the recommended solution to this problem.
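A minimal sketch of the difference (not the actual Captum code): tensors registered as buffers follow .to() / .cuda(), while plain attributes stay where they were created.

    import torch
    import torch.nn as nn

    class WithBuffer(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.register_buffer("scale", torch.ones(3))  # moved by .to() / .cuda()

    class WithAttribute(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.scale = torch.ones(3)  # plain attribute, ignored by .to() / .cuda()

    if torch.cuda.is_available():
        print(WithBuffer().cuda().scale.device)     # cuda:0
        print(WithAttribute().cuda().scale.device)  # cpu -> device-mismatch errors later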

@NarineK
Contributor

NarineK commented May 25, 2021

I looked into the FFTImage size thing, and there is a clear deviation between Ludwig's first Captum iteration of FFTImage and the Lucid code.

The Captum code creates the image with the shape given by the size variable, and then this same shape is used to create the scale / frequency tensor:

        coeffs_shape = (channels, size[0], size[1] // 2 + 1, 2)
        random_coeffs = torch.randn(
            coeffs_shape
        )  # names=["C", "H_f", "W_f", "complex"]
        self.fourier_coeffs = nn.Parameter(random_coeffs / 50)

        frequencies = FFTImage.rfft2d_freqs(*size)
  • The second-to-last dimension in coeffs_shape is always size[1] // 2 + 1 and never size[1] // 2 + 2, so making rfft2d_freqs always use // 2 + 1 prevents a size mismatch.

https://github.com/pytorch/captum/pull/412/files#diff-d9ef468be9704729ff7c3bd65ad5b115e206b5f418039e18cb333e123428dde5R160-R166

While Lucid derives the shape of the image from the shape of the frequency tensor (which the if statement changes when the width value is odd):

    batch, h, w, ch = shape
    freqs = rfft2d_freqs(h, w)
    init_val_size = (2, batch, ch) + freqs.shape

https://github.com/tensorflow/lucid/blob/master/lucid/optvis/param/spatial.py#L66-L67

Therefore, I think the special behavior for widths that have an odd value is required for the Lucid way of doing things, but not for the new Captum FFTImage. I think that Ludwig may have just forgotten to update rfft2d_freqs to match his changes in the init function.

Thank you for looking deeper into this:
It looks like we are dividing by 2 here:
https://github.com/pytorch/captum/pull/412/files#diff-d9ef468be9704729ff7c3bd65ad5b115e206b5f418039e18cb333e123428dde5R160-R166
and also here:
https://github.com/pytorch/captum/pull/656/files#diff-d1326a272667e088fe9934dd175f0be589edf7594ee01b7463451a5266c56b47R141
Is it necessary to do it in both places?

Oh, I think I understand why buffers were used now. They make it so that when a NaturalImage instance is placed onto the GPU, the buffers are also placed on the GPU.

This code works when spectrum_scale (FFTImage) and transform (ToRGB) are buffers, but results in a device error when I just define them as self variables:

obj = opt.InputOptimization(
    model,
    opt.loss.ChannelActivation(model.mixed4a, 476),
    input_param=opt.images.NaturalImage((224,224)).to(device),
)
obj.optimize()

obj.input_param().show()
/content/captum/captum/optim/_param/image/images.py in forward(self)
    180 
    181     def forward(self) -> torch.Tensor:
--> 182         scaled_spectrum = self.fourier_coeffs * self.spectrum_scale
    183         output = self.torch_irfft(scaled_spectrum)
    184         return output.refine_names("B", "C", "H", "W")

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I guess we could override .to() & .cuda() as a possible solution? Or we could override _apply so that both .to() and .cuda() work. Though I think that buffers are the recommended solution to this problem.


I think that it is tricky to put a module on the device, because PyTorch puts the actual tensors (parameters) related to that module on the device and not the module per se. For example, we cannot call .device on a module, and this makes it unclear what is on the device and what is not.
https://discuss.pytorch.org/t/how-to-check-if-model-is-on-cuda/180

To be clear, we can explicitly put the transform tensor on the device where the actual input is before we multiply with it.

We can set:
self.transform.to(x.device)

before using it, and use instance fields instead. Would that work?
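A minimal sketch of what that could look like (a toy module, not the actual ToRGB code; note that tensor.to() is not in-place, so the moved tensor has to be used or assigned back):

    import torch
    import torch.nn as nn

    class ToyColorTransform(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.transform = torch.eye(3)  # plain instance field instead of a buffer

        def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:
            transform = self.transform.to(x.device)  # match the input's device each call
            if inverse:
                transform = torch.inverse(transform)
            b, c, h, w = x.shape
            return (transform @ x.reshape(b, c, h * w)).reshape(b, c, h, w)

    if torch.cuda.is_available():
        x = torch.randn(1, 3, 4, 4).cuda()
        print(ToyColorTransform()(x).device)  # cuda:0, no device-mismatch error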

@NarineK
Contributor

NarineK commented May 25, 2021

I looked into the FFTImage size thing, and there is a clear deviation between Ludwig's first Captum iteration of FFTImage and the Lucid code.

The Captum code creates the image with the shape given by the size variable, and then this same shape is used to create the scale / frequency tensor:

        coeffs_shape = (channels, size[0], size[1] // 2 + 1, 2)
        random_coeffs = torch.randn(
            coeffs_shape
        )  # names=["C", "H_f", "W_f", "complex"]
        self.fourier_coeffs = nn.Parameter(random_coeffs / 50)

        frequencies = FFTImage.rfft2d_freqs(*size)
  • The second-to-last dimension in coeffs_shape is always size[1] // 2 + 1 and never size[1] // 2 + 2, so making rfft2d_freqs always use // 2 + 1 prevents a size mismatch.

https://github.com/pytorch/captum/pull/412/files#diff-d9ef468be9704729ff7c3bd65ad5b115e206b5f418039e18cb333e123428dde5R160-R166

While Lucid derives the shape of the image from the shape of the frequency tensor (which the if statement changes when the width value is odd):

    batch, h, w, ch = shape
    freqs = rfft2d_freqs(h, w)
    init_val_size = (2, batch, ch) + freqs.shape

https://github.com/tensorflow/lucid/blob/master/lucid/optvis/param/spatial.py#L66-L67

Therefore, I think the special behavior for widths that have an odd value is required for the Lucid way of doing things, but not for the new Captum FFTImage. I think that Ludwig may have just forgotten to update rfft2d_freqs to match his changes in the init function.

Thank you for looking into this, @ProGamerGov! As you mentioned, for coeffs_shape Ludwig used 2 + 1 and that is the inconsistency. I don't know exactly why he did it. We can try to ask and understand it better. Either way, it would be good to document those differences in the code so that we don't forget.

In terms of results, not shapes, are we getting different outcomes ?

Also there is a division by 50 magic number (this isn't related to the dimensionality) that I don't see in lucid

https://github.com/pytorch/captum/pull/412/files#diff-d9ef468be9704729ff7c3bd65ad5b115e206b5f418039e18cb333e123428dde5R164).

@ProGamerGov
Contributor Author

ProGamerGov commented May 26, 2021

@NarineK Yeah, moving the tensor to the input device for every forward call would work! But would that incur any sort of performance penalty for constantly moving the tensor to the target device? I tried testing it out, and there didn't seem to be any obvious effect on the time it takes to complete 512 iterations.

Alternatively, we could just include an additional parameter to NaturalImage to disable ToRGB or set the decorrelation_module to SkipLayer.

I've been disabling the color decorrelation like this for experiments:

image = opt.images.NaturalImage(
    parameterization=opt.images.PixelImage,
    init=init_image.cpu(),
    decorrelation_module=opt.models.SkipLayer(), # Disable color decorrelation
    squash_func=lambda x: x, # Disable squash function
).to(device)

Thank you for looking into this, @ProGamerGov ! As you mentioned for coeffs_shape Ludwig used 2 + 1 and that is the inconsistency. I don't know exactly why he did it. We can try to ask and understand it better. Either way, it would be good to document those differences in the code so that we don't forget.

In terms of results, not shapes, are we getting different outcomes ?

Also there is a division by 50 magic number (this isn't related to the dimensionality) that I don't see in lucid

https://github.com/pytorch/captum/pull/412/files#diff-d9ef468be9704729ff7c3bd65ad5b115e206b5f418039e18cb333e123428dde5R164).

Ludwig hasn't been active lately, so we might have to wait a bit to ask him why he made certain design choices. Chris may be able to help, but I think he is busy at the moment as well. Our version of FFTImage produces different results that appear to be superior to Lucid's FFTImage results. We don't divide the output by 4 or normalize the FFT operations, and I think there's a small difference in what is done to the scale variable before it's saved as spectrum_scale.

I think that he intended Captum's FFTImage to be the successor to Lucid's FFTImage, but I did create a Captum version of the old FFTImage a while back and found the results to be less detailed.

@ProGamerGov
Contributor Author

Thank you for looking deeper into this:
It looks like we are dividing by 2 here:
https://github.com/pytorch/captum/pull/412/files#diff-d9ef468be9704729ff7c3bd65ad5b115e206b5f418039e18cb333e123428dde5R160-R166
and also here:
https://github.com/pytorch/captum/pull/656/files#diff-d1326a272667e088fe9934dd175f0be589edf7594ee01b7463451a5266c56b47R141
Is it necessary to do it in both places?

Yes, I think it's necessary, as that way we get the right matching size for 5-dimensional FFT tensors:

import torch
import torch.fft
x = torch.randn(1, 3, 512, 405)
print(torch.view_as_real(torch.fft.rfftn(x, s=(512, 405))).shape)
# outputs torch.Size([1, 3, 512, 203, 2])
# 405 // 2 + 1 = 203

@NarineK
Contributor

NarineK commented May 26, 2021

@NarineK Yeah, moving the tensor to the input device for every forward call would work! But would that incur any sort of performance penalty for constantly moving the tensor to the target device? I tried testing it out, and there didn't seem to be any obvious effect on the time it takes to complete 512 iterations.

Alternatively, we could just include an additional parameter to NaturalImage to disable ToRGB or set the decorrelation_module to SkipLayer.

I've been disabling the color decorrelation like this for experiments:

image = opt.images.NaturalImage(
    parameterization=opt.images.PixelImage,
    init=init_image.cpu(),
    decorrelation_module=opt.models.SkipLayer(), # Disable color decorrelation
    squash_func=lambda x: x, # Disable squash function
).to(device)

Thank you for looking into this, @ProGamerGov ! As you mentioned for coeffs_shape Ludwig used 2 + 1 and that is the inconsistency. I don't know exactly why he did it. We can try to ask and understand it better. Either way, it would be good to document those differences in the code so that we don't forget.
In terms of results, not shapes, are we getting different outcomes ?
Also there is a division by 50 magic number (this isn't related to the dimensionality) that I don't see in lucid
https://github.com/pytorch/captum/pull/412/files#diff-d9ef468be9704729ff7c3bd65ad5b115e206b5f418039e18cb333e123428dde5R164).

Ludwig hasn't been active lately, so we might have to wait a bit to ask him why he made certain design choices. Chris may be able to help, but I think he is busy at the moment as well. Our version of FFTImage produces different results that appear to be superior to Lucid's FFTImage results. We don't divide the output by 4 or normalize the FFT operations, and I think there's a small difference in what is done to the scale variable before it's saved as spectrum_scale.

I think that he intended Captum's FFTImage to be the successor to the Lucid's FFTImage, but I did create a Captum version of the old FFTImage a while back and found the results to be less detailed.

I don't think that there will be performance issues; it is simply linking the device. I don't think that .to(device) is costly. It is most probably a simple linking.

Alternatively, we could also pass device as an optional argument to the constructor. The user would need to specify it explicitly, but we can specifically document that the device is used for the transformation tensor.

I think that users need to understand SkipLayer and the additional logic related to input arguments, whereas passing a device is a much easier concept to grasp.
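A hypothetical illustration of that suggestion (not the actual Captum API; the class and parameter names here are made up): the caller passes a device explicitly, and it is only applied to the decorrelation transform.

    from typing import Optional

    import torch
    import torch.nn as nn

    class ExampleImageParam(nn.Module):
        def __init__(
            self,
            decorrelation_module: Optional[nn.Module] = None,
            device: torch.device = torch.device("cpu"),
        ) -> None:
            super().__init__()
            # The documented contract: device only affects the decorrelation transform.
            self.decorrelate = (decorrelation_module or nn.Identity()).to(device)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.decorrelate(x)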

@ProGamerGov
Contributor Author

ProGamerGov commented May 26, 2021

@NarineK I just realized that there may be a better way that we could solve this with minimal changes to NaturalImage. I think that we could just do self.decorrelate = decorrelation_module.cpu() to ensure that the decorrelation_module always starts on the CPU, and then we can still create the ToRGB instance in the __init__ function signature. The init tensor is already expected to be on the CPU, if it is not set to None. So this might be a more elegant solution?

This way we would keep the .to() and .cuda() behavior of the buffers, and it means that we don't have to change the code in ToRGB. The original intended behavior of the decorrelation_module where it can be set to None to disable color decorrelation would also be preserved.

    def __init__(
        self,
        decorrelation_module: Optional[nn.Module] = ToRGB(transform="klt"),
    ) -> None:
        super().__init__()

        # Place decorrelation_module on cpu if it's not None
        self.decorrelate = (
            decorrelation_module.cpu() if decorrelation_module is not None else None
        )

@ProGamerGov
Contributor Author

@NarineK I've implemented the decorrelation_module.cpu() fix and resolved the Conda test issue. I think that all of the issues raised in this PR have been resolved now, and thus it may be ready for merging?

@NarineK
Contributor

NarineK commented May 27, 2021

@NarineK I've implemented the decorrelation_module.cpu() fix and resolved the Conda test issue. I think that all of the issues raised in this PR have been resolved now, and thus it may be ready for merging?

This means that the tensor coming out of decorrelation_module will always be on the CPU, even if the user explicitly puts it on the GPU in the decorrelation_module; we will override it with the CPU. Since this is not critical we can leave it as is. I haven't reviewed the rest of the PR yet. Let me have a quick look.

            decorrelation_module.cpu() if decorrelation_module is not None else None
        )
        if init is not None:
            assert not init.is_cuda
Contributor

@NarineK NarineK Jun 1, 2021


Enforcing the tensor to be on the CPU without explaining why will make the user think that there is a bug in the code.
We should explicitly put it on the CPU. I assume this is because of decorrelation_module.cpu().

I think that it would be good to print a warning message about why we put decorrelation_module on the CPU and explain it in the documentation.

Perhaps also add a TODO so that we can make the module more flexible and don't have to move it to the CPU device.

Contributor Author

@ProGamerGov ProGamerGov Jun 1, 2021


@NarineK The init tensor had to be put on the CPU because FFTImage expected it to be on the CPU, but I just fixed it so that spectrum_scale is now placed onto the init tensor's device.

In order to avoid any future issues with the decorrelation module being created in the function signature, I added a deepcopy line that only runs when it's a ToRGB instance.

I also removed the device assertion and made ToRGB place the transform on the input tensor's device as well, so that it becomes a no-op when the buffer is placed on the target device.

The device issues should now be resolved I think?

Contributor Author


I removed the deepcopy line as it was redundant.

Contributor

@NarineK NarineK Jun 7, 2021


It looks cleaner now.

Here squash_func is not assigned, right? We could also represent it as a one-line lambda function. I wonder why we use different squash_funcs depending on whether init is provided or not.

if squash_func is None:

    def squash_func(x: torch.Tensor) -> torch.Tensor:
        return x.clamp(0, 1)

Also, do you know what is meant by the comment # TODO: the world is not yet ready at line 460?

Contributor Author

@ProGamerGov ProGamerGov Jun 7, 2021


@NarineK The different squash function for when init tensors are provided was something I found to produce better results with images. Lambda functions were used previously, but SK thought inner functions were better.

The TODO line is because named dimensions are not fully supported by PyTorch yet.

Contributor


Interesting, why are inner functions better in this case?
rename is used in line 440 too. We can leave it, but ideally we should only use it for the PyTorch versions that support it.

"""

def forward(self, x: torch.Tensor) -> torch.Tensor:
def __init__(self, *args, **kwargs) -> None:
Contributor


Why are the args and kwargs passed if they are not used? Are there any real scenarios where we need them? Is this because we replace ReLU with SkipLayer?
I could imagine, for example, that x is a tuple of tensors and we would need to return that tuple.

Contributor Author


I made the change because users may have models using activ(inplace=False) or activ(False) where activ = torch.nn.ReLU. This is the same way that torch.nn.Identity works: https://pytorch.org/docs/stable/generated/torch.nn.Identity.html
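For reference, nn.Identity accepts and discards any constructor arguments, e.g.:

    import torch
    import torch.nn as nn

    identity = nn.Identity(54, unused_argument=False)  # extra args are ignored
    x = torch.randn(2, 3)
    assert torch.equal(identity(x), x)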

I'll add the type hint for tuples of tensors.

Contributor Author

@ProGamerGov ProGamerGov Jun 1, 2021


I've improved the documentation, added the type hints, and also provided a link to the nn.Identity class!

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
Contributor


Why are the args and kwargs passed if they are not used? We should document this, because it is unclear why it is needed.

Contributor Author

@ProGamerGov ProGamerGov Jun 1, 2021


I made the change for situations like replacing the decorrelation module, as it specifies an additional argument in the forward pass. I'll add some documentation explaining why I added args and kwargs. Without this change, SkipLayer is exactly the same as torch.nn.Identity.

Contributor


Yes, this will work for ReLU with one input tensor and specific cases like ReLUs, but it will have issues with custom modules that take tuples of tensors, and with modules such as nn.Linear, etc. I think that perhaps it would be good to give it a more specific name, or to describe that this is a skip layer which assumes that the inputs and outputs have the same shape and that only the first tensor is returned as the output.

@NarineK NarineK self-requested a review June 7, 2021 00:24
Contributor

@NarineK NarineK left a comment


LGTM! Thank you for addressing the comments, @ProGamerGov!
I left two minor comments.

@NarineK NarineK merged commit 46e16e4 into meta-pytorch:optim-wip Jun 7, 2021