81 changes: 61 additions & 20 deletions captum/optim/_param/image/images.py
@@ -138,6 +138,8 @@ class FFTImage(ImageParameterization):
     Parameterize an image using inverse real 2D FFT
     """
 
+    __constants__ = ["size", "_supports_is_scripting"]
+
     def __init__(
         self,
         size: Tuple[int, int] = None,
@@ -197,6 +199,9 @@ def __init__(
         self.register_buffer("spectrum_scale", spectrum_scale)
         self.fourier_coeffs = nn.Parameter(fourier_coeffs)
 
+        # Check & store whether or not we can use torch.jit.is_scripting()
+        self._supports_is_scripting = torch.__version__ >= "1.6.0"
+
     def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor:
         """
         Computes 2D spectrum frequencies.
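One subtlety in the version check added above, sketched here for clarity (the sketch is mine, not part of the diff): `torch.__version__` is compared lexicographically as a string, which behaves as intended for the 1.x releases this file targets but misorders double-digit minor versions.

```python
# Caveat sketch, not from the PR: string comparison of version numbers.
assert "1.7.0" >= "1.6.0"  # works as intended here
assert "1.10.0" < "1.6.0"  # lexicographic surprise: "1.1" sorts before "1.6"
```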
@@ -214,6 +219,12 @@ def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor:
         fx = self.torch_fftfreq(width)[: width // 2 + 1]
         return torch.sqrt((fx * fx) + (fy * fy))
 
+    @torch.jit.export
+    def torch_irfftn(self, x: torch.Tensor) -> torch.Tensor:
+        if x.dtype != torch.complex64:
+            x = torch.view_as_complex(x)
+        return torch.fft.irfftn(x, s=self.size)  # type: ignore
+
     def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]:
         """
         Support older versions of PyTorch. This function ensures that the same FFT
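For readers skimming the hunk above: `torch_irfftn` is exported with `@torch.jit.export` so scripted modules can call it, and the dtype check lets it accept either a complex spectrum or the real `view_as_real` layout. A minimal round-trip sketch of that layout handling (mine, assuming PyTorch >= 1.7; shapes are illustrative):

```python
# Sketch: the real-view <-> complex round trip that torch_irfftn relies on.
import torch

size = (4, 6)
x = torch.randn(1, 3, *size)
spectrum = torch.view_as_real(torch.fft.rfftn(x, s=size))  # float tensor, last dim == 2
assert spectrum.dtype != torch.complex64                   # so it gets re-viewed as complex
recon = torch.fft.irfftn(torch.view_as_complex(spectrum), s=size)
assert torch.allclose(x, recon, atol=1e-5)                 # inverse recovers the input
```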
@@ -226,26 +237,24 @@ def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]:
         """
 
         if TORCH_VERSION >= "1.7.0":
-            import torch.fft
+            if TORCH_VERSION < "1.8.0":
+                global torch
+                import torch.fft
 
             def torch_rfft(x: torch.Tensor) -> torch.Tensor:
                 return torch.view_as_real(torch.fft.rfftn(x, s=self.size))
 
-            def torch_irfft(x: torch.Tensor) -> torch.Tensor:
-                if type(x) is not torch.complex64:
-                    x = torch.view_as_complex(x)
-                return torch.fft.irfftn(x, s=self.size)  # type: ignore
+            torch_irfftn = self.torch_irfftn
 
             def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor:
                 return torch.fft.fftfreq(v, d)
 
         else:
-            import torch
 
             def torch_rfft(x: torch.Tensor) -> torch.Tensor:
                 return torch.rfft(x, signal_ndim=2)
 
-            def torch_irfft(x: torch.Tensor) -> torch.Tensor:
+            def torch_irfftn(x: torch.Tensor) -> torch.Tensor:
                 return torch.irfft(x, signal_ndim=2)[
                     :, :, : self.size[0], : self.size[1]
                 ]
@@ -258,7 +267,7 @@ def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor:
                 results[s:] = torch.arange(-(v // 2), 0)
                 return results * (1.0 / (v * d))
 
-        return torch_rfft, torch_irfft, torch_fftfreq
+        return torch_rfft, torch_irfftn, torch_fftfreq
 
     def forward(self) -> torch.Tensor:
         """
@@ -268,6 +277,9 @@ def forward(self) -> torch.Tensor:
 
         scaled_spectrum = self.fourier_coeffs * self.spectrum_scale
         output = self.torch_irfft(scaled_spectrum)
+        if self._supports_is_scripting:
+            if torch.jit.is_scripting():
+                return output
         return output.refine_names("B", "C", "H", "W")
 
 
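The guarded early return above is the standard `torch.jit.is_scripting()` pattern: eager callers keep the named-tensor output, while scripted modules (which cannot use named tensors) get the raw tensor. A self-contained sketch of the same pattern (module name is mine, not from the diff):

```python
# Sketch of the scripting gate used by FFTImage.forward and PixelImage.forward.
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.image = nn.Parameter(torch.zeros(1, 3, 8, 8))

    def forward(self) -> torch.Tensor:
        if torch.jit.is_scripting():
            return self.image  # scripted path: plain tensor
        return self.image.refine_names("B", "C", "H", "W")  # eager path: named dims

print(Toy()().names)                    # ('B', 'C', 'H', 'W')
print(torch.jit.script(Toy())().shape)  # torch.Size([1, 3, 8, 8])
```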
@@ -276,6 +288,8 @@ class PixelImage(ImageParameterization):
     Parameterize a simple pixel image tensor that requires no additional transforms.
     """
 
+    __constants__ = ["_supports_is_scripting"]
+
     def __init__(
         self,
         size: Tuple[int, int] = None,
@@ -309,7 +323,13 @@ def __init__(
                 f"input has {init.shape[1]} channels."
         self.image = nn.Parameter(init)
 
+        # Check & store whether or not we can use torch.jit.is_scripting()
+        self._supports_is_scripting = torch.__version__ >= "1.6.0"
+
     def forward(self) -> torch.Tensor:
+        if self._supports_is_scripting:
+            if torch.jit.is_scripting():
+                return self.image
         return self.image.refine_names("B", "C", "H", "W")
 
 
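Both classes also list `_supports_is_scripting` in `__constants__`. Naming an attribute there marks it as a compile-time constant for TorchScript, so the version guard can be folded at script time rather than evaluated on every call. A generic sketch of that mechanism (class and attribute names are mine):

```python
# Sketch: __constants__ lets TorchScript treat a plain Python bool as a constant.
import torch
import torch.nn as nn

class Gate(nn.Module):
    __constants__ = ["use_fast_path"]

    def __init__(self, use_fast_path: bool) -> None:
        super().__init__()
        self.use_fast_path = use_fast_path

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_fast_path:
            return x * 2.0  # branch kept when the constant is True
        return x + x

scripted = torch.jit.script(Gate(True))
print(scripted(torch.ones(2)))  # tensor([2., 2.])
```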
@@ -600,7 +620,7 @@ def __init__(
             nn.Parameter tensor, or stacking init images.
             Default: 1
         parameterization (ImageParameterization, optional): An image
-            parameterization class.
+            parameterization class, or instance of an image parameterization class.
             Default: FFTImage
         squash_func (Callable[[torch.Tensor], torch.Tensor], optional): The squash
             function to use after color recorrelation. A function or lambda function.
@@ -612,8 +632,14 @@ def __init__(
             Default: True
         """
         super().__init__()
+        if not isinstance(parameterization, ImageParameterization):
+            # Verify uninitialized class is correct type
+            assert issubclass(parameterization, ImageParameterization)
+        else:
+            assert isinstance(parameterization, ImageParameterization)
+
         self.decorrelate = decorrelation_module
-        if init is not None:
+        if init is not None and not isinstance(parameterization, ImageParameterization):
             assert init.dim() == 3 or init.dim() == 4
             if decorrelate_init and self.decorrelate is not None:
                 init = (
@@ -622,27 +648,42 @@ def __init__(
                     else init.refine_names("C", "H", "W")
                 )
                 init = self.decorrelate(init, inverse=True).rename(None)
 
             if squash_func is None:
+                squash_func = self._clamp_image
 
-                def squash_func(x: torch.Tensor) -> torch.Tensor:
-                    return x.clamp(0, 1)
+        self.squash_func = torch.sigmoid if squash_func is None else squash_func
+        if not isinstance(parameterization, ImageParameterization):
+            parameterization = parameterization(
+                size=size, channels=channels, batch=batch, init=init
+            )
+        self.parameterization = parameterization
 
-        else:
-            if squash_func is None:
+    @torch.jit.export
+    def _clamp_image(self, x: torch.Tensor) -> torch.Tensor:
+        """JIT supported squash function."""
+        return x.clamp(0, 1)
+
-                squash_func = torch.sigmoid
+    @torch.jit.ignore
+    def _to_image_tensor(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Wrap ImageTensor in torch.jit.ignore for JIT support.
 
-        self.squash_func = squash_func
-        self.parameterization = parameterization(
-            size=size, channels=channels, batch=batch, init=init
-        )
+        Args:
+
+            x (torch.tensor): An input tensor.
+
+        Returns:
+            x (ImageTensor): An instance of ImageTensor with the input tensor.
+        """
+        return ImageTensor(x)
 
     def forward(self) -> torch.Tensor:
         image = self.parameterization()
         if self.decorrelate is not None:
             image = self.decorrelate(image)
         image = image.rename(None)  # TODO: the world is not yet ready
-        return ImageTensor(self.squash_func(image))
+        return self._to_image_tensor(self.squash_func(image))
 
 
 __all__ = [
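Taken together, the NaturalImage changes mean callers may now pass either a parameterization class (constructed internally from `size`, `channels`, `batch`, and `init`) or a ready-made instance. A usage sketch under that reading (argument values are illustrative, and it assumes the `captum.optim` module is available in your build):

```python
# Sketch of the widened NaturalImage API after this diff.
from captum.optim._param.image.images import FFTImage, NaturalImage, PixelImage

image_a = NaturalImage(size=(224, 224), parameterization=FFTImage)    # class: built internally
image_b = NaturalImage(parameterization=PixelImage(size=(224, 224)))  # pre-built instance
out = image_b()  # forward() -> ImageTensor, squashed to [0, 1] (sigmoid by default here)
```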