diff --git a/captum/optim/__init__.py b/captum/optim/__init__.py
old mode 100755
new mode 100644
diff --git a/captum/optim/_core/optimization.py b/captum/optim/_core/optimization.py
index 27c5bf3162..c251dfc8ec 100644
--- a/captum/optim/_core/optimization.py
+++ b/captum/optim/_core/optimization.py
@@ -46,6 +46,7 @@ def __init__(
     ) -> None:
         r"""
         Args:
+            model (nn.Module): A reference to the PyTorch model instance.
             input_param (nn.Module, optional): A module that generates an input,
                 consumed by the model.
@@ -71,6 +72,7 @@ def loss(self) -> torch.Tensor:
     def loss(self) -> torch.Tensor:
         r"""Compute loss value for current iteration.
+
         Returns:
             *tensor* representing **loss**:
             - **loss** (*tensor*):
@@ -115,18 +117,26 @@ def optimize(
         lr: float = 0.025,
     ) -> torch.Tensor:
         r"""Optimize input based on loss function and objectives.
+
         Args:
+
            stop_criteria (StopCriteria, optional): A function that is called
                every iteration and returns a bool that determines whether
                to stop the optimization.
                See captum.optim.typing.StopCriteria for details.
            optimizer (Optimizer, optional): An torch.optim.Optimizer used to
                optimize the input based on the loss function.
+           loss_summarize_fn (Callable, optional): The function to use for
+               summarizing tensor outputs from loss functions.
+               Default: default_loss_summarize
+           lr (float, optional): If no optimizer is given, then lr is used as
+               the learning rate for the Adam optimizer.
+               Default: 0.025
+
        Returns:
-           *list* of *np.arrays* representing the **history**:
-           - **history** (*list*):
-               A list of loss values per iteration.
-               Length of the list corresponds to the number of iterations
+           history (torch.Tensor): A stack of loss values per iteration. The
+               size of the dimension on which loss values are stacked
+               corresponds to the number of iterations.
        """
        stop_criteria = stop_criteria or n_steps(512)
        optimizer = optimizer or optim.Adam(self.parameters(), lr=lr)
@@ -150,10 +160,12 @@ def optimize(
 def n_steps(n: int, show_progress: bool = True) -> StopCriteria:
     """StopCriteria generator that uses number of steps as a stop criteria.
+
     Args:
         n (int): Number of steps to run optimization.
         show_progress (bool, optional): Whether or not to show progress bar.
             Default: True
+
     Returns:
         *StopCriteria* callable
     """
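A minimal usage sketch of the optimize() API documented above. The loss objective and the InputOptimization keyword argument names shown here are illustrative assumptions, not part of this patch:

    import torchvision
    from captum.optim._core.optimization import InputOptimization, n_steps
    from captum.optim._param.image.images import NaturalImage

    # Hypothetical setup: loss_fn stands in for a captum.optim loss objective
    # (construction omitted); the constructor kwargs are assumed names.
    model = torchvision.models.googlenet(pretrained=True).eval()
    obj = InputOptimization(model, loss_function=loss_fn, input_param=NaturalImage())

    # Per the docstring above, optimize() returns a tensor stacking one loss
    # value per iteration; here 512 steps via the n_steps stop criteria.
    history = obj.optimize(stop_criteria=n_steps(512), lr=0.025)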
+ """ + def forward_hook( module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor ) -> None: @@ -49,6 +60,12 @@ def forward_hook( return forward_hook def consume_outputs(self) -> ModuleOutputMapping: + """ + Collect target activations and return them. + + Returns: + outputs (ModuleOutputMapping): The captured outputs. + """ if not self.is_ready: warn( "Consume captured outputs, but not all requested target outputs " @@ -63,11 +80,16 @@ def targets(self) -> Iterable[nn.Module]: return self.outputs.keys() def remove_hooks(self) -> None: + """ + Remove hooks. + """ for hook in self.hooks: hook.remove() def __del__(self) -> None: - # print(f"DEL HOOKS!: {list(self.outputs.keys())}") + """ + Ensure that using 'del' properly deletes hooks. + """ self.remove_hooks() @@ -77,16 +99,34 @@ class ActivationFetcher: """ def __init__(self, model: nn.Module, targets: Iterable[nn.Module]) -> None: + """ + Args: + + model (nn.Module): The reference to PyTorch model instance. + targets (nn.Module or list of nn.Module): The target layers to + collect activations from. + """ super(ActivationFetcher, self).__init__() self.model = model self.layers = ModuleOutputsHook(targets) def __call__(self, input_t: TupleOfTensorsOrTensorType) -> ModuleOutputMapping: + """ + Args: + + input_t (tensor or tuple of tensors, optional): The input to use + with the specified model. + + Returns: + activations_dict: An dict containing the collected activations. The keys + for the returned dictionary are the target layers. + """ + try: with warnings.catch_warnings(): warnings.simplefilter("ignore") self.model(input_t) - activations = self.layers.consume_outputs() + activations_dict = self.layers.consume_outputs() finally: self.layers.remove_hooks() - return activations + return activations_dict diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py old mode 100755 new mode 100644 index b0852a512c..cf4b01da0d --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -27,6 +27,15 @@ def __new__( *args, **kwargs, ) -> torch.Tensor: + """ + Args: + + x (list or np.ndarray or torch.Tensor): A list, NumPy array, or PyTorch + tensor to create an `ImageTensor` from. + + Returns: + x (ImageTensor): An `ImageTensor` instance. + """ if isinstance(x, torch.Tensor) and x.is_cuda: x.show = MethodType(cls.show, x) x.export = MethodType(cls.export, x) @@ -36,6 +45,20 @@ def __new__( @classmethod def open(cls, path: str, scale: float = 255.0, mode: str = "RGB") -> "ImageTensor": + """ + Load an image file from a URL or local filepath directly into an `ImageTensor`. + + Args: + + path (str): A URL or filepath to an image. + scale (float, optional): The image scale to use. + Default: 255.0 + mode (str, optional): The image loading mode to use. + Default: "RGB" + + Returns: + x (ImageTensor): An `ImageTensor` instance. + """ if path.startswith("https://") or path.startswith("http://"): response = requests.get(path, stream=True) img = Image.open(response.raw) @@ -73,9 +96,31 @@ def __torch_function__( def show( self, figsize: Optional[Tuple[int, int]] = None, scale: float = 255.0 ) -> None: + """ + Display an `ImageTensor`. + + Args: + + figsize (Tuple[int, int], optional): height & width to use + for displaying the `ImageTensor` figure. + scale (float, optional): Value to multiply the `ImageTensor` by so that + it's value range is [0-255] for display. 
diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py
old mode 100755
new mode 100644
index b0852a512c..cf4b01da0d
--- a/captum/optim/_param/image/images.py
+++ b/captum/optim/_param/image/images.py
@@ -27,6 +27,15 @@ def __new__(
         *args,
         **kwargs,
     ) -> torch.Tensor:
+        """
+        Args:
+
+            x (list or np.ndarray or torch.Tensor): A list, NumPy array, or
+                PyTorch tensor to create an `ImageTensor` from.
+
+        Returns:
+            x (ImageTensor): An `ImageTensor` instance.
+        """
         if isinstance(x, torch.Tensor) and x.is_cuda:
             x.show = MethodType(cls.show, x)
             x.export = MethodType(cls.export, x)
@@ -36,6 +45,20 @@ def __new__(
 
     @classmethod
     def open(cls, path: str, scale: float = 255.0, mode: str = "RGB") -> "ImageTensor":
+        """
+        Load an image file from a URL or local filepath directly into an
+        `ImageTensor`.
+
+        Args:
+
+            path (str): A URL or filepath to an image.
+            scale (float, optional): The image scale to use.
+                Default: 255.0
+            mode (str, optional): The image loading mode to use.
+                Default: "RGB"
+
+        Returns:
+            x (ImageTensor): An `ImageTensor` instance.
+        """
         if path.startswith("https://") or path.startswith("http://"):
             response = requests.get(path, stream=True)
             img = Image.open(response.raw)
@@ -73,9 +96,31 @@ def __torch_function__(
     def show(
         self, figsize: Optional[Tuple[int, int]] = None, scale: float = 255.0
     ) -> None:
+        """
+        Display an `ImageTensor`.
+
+        Args:
+
+            figsize (Tuple[int, int], optional): The height & width to use
+                for displaying the `ImageTensor` figure.
+            scale (float, optional): Value to multiply the `ImageTensor` by so
+                that its value range is [0, 255] for display.
+                Default: 255.0
+        """
         show(self, figsize=figsize, scale=scale)
 
     def export(self, filename: str, scale: float = 255.0) -> None:
+        """
+        Save an `ImageTensor` as an image file.
+
+        Args:
+
+            filename (str): The filename to use when saving the `ImageTensor`
+                as an image file.
+            scale (float, optional): Value to multiply the `ImageTensor` by so
+                that its value range is [0, 255] for saving.
+                Default: 255.0
+        """
         save_tensor_as_image(self, filename=filename, scale=scale)
 
 
@@ -89,7 +134,9 @@ class ImageParameterization(InputParameterization):
 
 
 class FFTImage(ImageParameterization):
-    """Parameterize an image using inverse real 2D FFT"""
+    """
+    Parameterize an image using an inverse real 2D FFT.
+    """
 
     def __init__(
         self,
@@ -98,6 +145,20 @@ def __init__(
         batch: int = 1,
         init: Optional[torch.Tensor] = None,
     ) -> None:
+        """
+        Args:
+
+            size (Tuple[int, int]): The height & width dimensions to use for the
+                parameterized output image tensor.
+            channels (int, optional): The number of channels to use for each image.
+                Default: 3
+            batch (int, optional): The number of images to stack along the batch
+                dimension.
+                Default: 1
+            init (torch.tensor, optional): Optionally specify a tensor to
+                use instead of creating one.
+                Default: None
+        """
         super().__init__()
         if init is None:
             assert len(size) == 2
@@ -137,13 +198,33 @@ def __init__(
         self.fourier_coeffs = nn.Parameter(fourier_coeffs)
 
     def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor:
-        """Computes 2D spectrum frequencies."""
+        """
+        Computes 2D spectrum frequencies.
+
+        Args:
+
+            height (int): The height dimension of the 2d frequency scale.
+            width (int): The width dimension of the 2d frequency scale.
+
+        Returns:
+            **tensor** (tensor): A 2d frequency scale tensor.
+        """
+
         fy = self.torch_fftfreq(height)[:, None]
         fx = self.torch_fftfreq(width)[: width // 2 + 1]
         return torch.sqrt((fx * fx) + (fy * fy))
 
     def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]:
-        """Support older versions of PyTorch"""
+        """
+        Support older versions of PyTorch. This function ensures that the same FFT
+        operations are carried out regardless of whether your PyTorch version has
+        the torch.fft update.
+
+        Returns:
+            fft functions (tuple of Callable): A tuple of FFT functions to use
+                for irfft, rfft, and fftfreq operations.
+        """
+
         if TORCH_VERSION >= "1.7.0":
             import torch.fft
 
@@ -180,12 +261,21 @@ def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor:
         return torch_rfft, torch_irfft, torch_fftfreq
 
     def forward(self) -> torch.Tensor:
+        """
+        Returns:
+            **output** (torch.tensor): A spatially recorrelated tensor.
+        """
+
         scaled_spectrum = self.fourier_coeffs * self.spectrum_scale
         output = self.torch_irfft(scaled_spectrum)
         return output.refine_names("B", "C", "H", "W")
 
 
 class PixelImage(ImageParameterization):
+    """
+    Parameterize a simple pixel image tensor that requires no additional transforms.
+    """
+
     def __init__(
         self,
         size: Tuple[int, int] = None,
@@ -193,6 +283,20 @@ def __init__(
         batch: int = 1,
         init: Optional[torch.Tensor] = None,
     ) -> None:
+        """
+        Args:
+
+            size (Tuple[int, int]): The height & width dimensions to use for the
+                parameterized output image tensor.
+            channels (int, optional): The number of channels to use for each image.
+                Default: 3
+            batch (int, optional): The number of images to stack along the batch
+                dimension.
+                Default: 1
+            init (torch.tensor, optional): Optionally specify a tensor to
+                use instead of creating one.
+                Default: None
+        """
         super().__init__()
         if init is None:
             assert size is not None and channels is not None and batch is not None
@@ -212,6 +316,7 @@ def forward(self) -> torch.Tensor:
 class LaplacianImage(ImageParameterization):
     """
     TODO: Fix divison by 6 in setup_input when init is not None.
+    Parameterize an image tensor with a Laplacian pyramid.
     """
 
     def __init__(
         self,
@@ -221,11 +326,25 @@ def __init__(
         batch: int = 1,
         init: Optional[torch.Tensor] = None,
     ) -> None:
+        """
+        Args:
+
+            size (Tuple[int, int]): The height & width dimensions to use for the
+                parameterized output image tensor.
+            channels (int, optional): The number of channels to use for each image.
+                Default: 3
+            batch (int, optional): The number of images to stack along the batch
+                dimension.
+                Default: 1
+            init (torch.tensor, optional): Optionally specify a tensor to
+                use instead of creating one.
+                Default: None
+        """
         super().__init__()
         power = 0.1
 
         if init is None:
-            tensor_params, self.scaler = self.setup_input(size, channels, power, init)
+            tensor_params, self.scaler = self._setup_input(size, channels, power, init)
 
             self.tensor_params = torch.nn.ModuleList(
                 [deepcopy(tensor_params) for b in range(batch)]
@@ -234,13 +353,13 @@ def __init__(
             init = init.unsqueeze(0) if init.dim() == 3 else init
             P = []
             for b in range(init.size(0)):
-                tensor_params, self.scaler = self.setup_input(
+                tensor_params, self.scaler = self._setup_input(
                     size, channels, power, init[b].unsqueeze(0)
                 )
                 P.append(tensor_params)
             self.tensor_params = torch.nn.ModuleList(P)
 
-    def setup_input(
+    def _setup_input(
         self,
         size: Tuple[int, int],
         channels: int,
@@ -264,16 +383,26 @@ def setup_input(
         tensor_params = torch.nn.ParameterList(tensor_params)
         return tensor_params, scaler
 
-    def create_tensor(self, params_list: torch.nn.ParameterList) -> torch.Tensor:
-        A = []
+    def _create_tensor(self, params_list: torch.nn.ParameterList) -> torch.Tensor:
+        """
+        Resize tensor parameters to the target size.
+
+        Args:
+
+            params_list (torch.nn.ParameterList): List of tensors to resize.
+
+        Returns:
+            **tensor** (torch.Tensor): The sum of all tensor parameters.
+        """
+        A: List[torch.Tensor] = []
         for xi, upsamplei in zip(params_list, self.scaler):
             A.append(upsamplei(xi))
         return torch.sum(torch.cat(A), 0) + 0.5
 
     def forward(self) -> torch.Tensor:
-        A = []
+        A: List[torch.Tensor] = []
         for params_list in self.tensor_params:
-            tensor = self.create_tensor(params_list)
+            tensor = self._create_tensor(params_list)
             A.append(tensor)
         return torch.stack(A).refine_names("B", "C", "H", "W")
 
@@ -297,6 +426,17 @@ def __init__(
         parameterization: ImageParameterization = None,
         offset: Union[int, Tuple[int], Tuple[Tuple[int]], None] = None,
     ) -> None:
+        """
+        Args:
+
+            shapes (list of int or list of list of ints): The shapes of the shared
+                tensors to use for creating the nn.Parameter tensors.
+            parameterization (ImageParameterization): An image parameterization
+                instance.
+            offset (int or list of int or list of list of ints, optional): The
+                offsets to use for the shared tensors.
+                Default: None
+        """
         super().__init__()
         assert shapes is not None
         A = []
@@ -308,9 +448,21 @@ def __init__(
             A.append(torch.nn.Parameter(torch.randn([batch, channels, height, width])))
         self.shared_init = torch.nn.ParameterList(A)
         self.parameterization = parameterization
-        self.offset = self.get_offset(offset, len(A)) if offset is not None else None
+        self.offset = self._get_offset(offset, len(A)) if offset is not None else None
+
+    def _get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]]:
+        """
+        Given offset values, return a list of offsets for _apply_offset to use.
+
+        Args:
+
+            offset (int or list of int or list of list of ints, optional): The
+                offsets to use for the shared tensors.
+            n (int): The number of tensors needing offset values.
 
-    def get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]]:
+        Returns:
+            **offset** (list of list of int): A list of offset values.
+        """
         if type(offset) is tuple or type(offset) is list:
             if type(offset[0]) is tuple or type(offset[0]) is list:
                 assert len(offset) == n and all(len(t) == 4 for t in offset)
@@ -323,8 +475,19 @@ def get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]]:
         assert all([all([type(o) is int for o in v]) for v in offset])
         return offset
 
-    def apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
-        A = []
+    def _apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
+        """
+        Apply a list of offsets to a list of tensors.
+
+        Args:
+
+            x_list (list of torch.Tensor): A list of tensors to offset.
+
+        Returns:
+            **A** (list of torch.Tensor): A list of offset tensors.
+        """
+
+        A: List[torch.Tensor] = []
         for x, offset in zip(x_list, self.offset):
             assert x.dim() == 4
             size = list(x.size())
@@ -345,13 +508,23 @@ def apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
             A.append(x)
         return A
 
-    def interpolate_tensor(
+    def _interpolate_tensor(
         self, x: torch.Tensor, batch: int, channels: int, height: int, width: int
     ) -> torch.Tensor:
         """
-        Linear interpolation for 4D, 5D, and 6D tensors.
-        If the batch dimension needs to be resized,
-        we move it's location temporarily for F.interpolate.
+        Linear interpolation for 4D, 5D, and 6D tensors. If the batch dimension
+        needs to be resized, we move its location temporarily for F.interpolate.
+
+        Args:
+
+            x (torch.Tensor): The tensor to resize.
+            batch (int): The batch size to resize the tensor to.
+            channels (int): The channel size to resize the tensor to.
+            height (int): The height to resize the tensor to.
+            width (int): The width to resize the tensor to.
+
+        Returns:
+            **tensor** (torch.Tensor): A resized tensor.
         """
 
         if x.size(1) == channels:
@@ -376,7 +549,7 @@ def forward(self) -> torch.Tensor:
     def forward(self) -> torch.Tensor:
         image = self.parameterization()
         x = [
-            self.interpolate_tensor(
+            self._interpolate_tensor(
                 shared_tensor,
                 image.size(0),
                 image.size(1),
@@ -386,7 +559,7 @@ def forward(self) -> torch.Tensor:
             for shared_tensor in self.shared_init
         ]
         if self.offset is not None:
-            x = self.apply_offset(x)
+            x = self._apply_offset(x)
         return (image + sum(x)).refine_names("B", "C", "H", "W")
 
 
@@ -401,21 +574,6 @@ class NaturalImage(ImageParameterization):
     If a model requires a normalization step, such as normalizing imagenet RGB
     values, or rescaling to [0,255], it can perform those steps with the
     provided transforms or inside its computation.
-
-    Arguments:
-        size (Tuple[int, int]): The height and width to use for the nn.Parameter image
-            tensor.
-        channels (int): The number of channels to use when creating the
-            nn.Parameter tensor.
-        batch (int): The number of channels to use when creating the
-            nn.Parameter tensor, or stacking init images.
-        parameterization (ImageParameterization, optional): An image parameterization
-            class.
-        squash_func (Callable[[torch.Tensor], torch.Tensor]], optional): The squash
-            function to use after color recorrelation. A funtion or lambda function.
-        decorrelation_module (nn.Module, optional): A ToRGB instance.
-        decorrelate_init (bool, optional): Whether or not to apply color decorrelation
-            to the init tensor input.
     """
 
     def __init__(
@@ -429,6 +587,30 @@ def __init__(
         decorrelation_module: Optional[nn.Module] = ToRGB(transform="klt"),
         decorrelate_init: bool = True,
     ) -> None:
+        """
+        Args:
+
+            size (Tuple[int, int], optional): The height and width to use for the
+                nn.Parameter image tensor.
+                Default: (224, 224)
+            channels (int, optional): The number of channels to use when creating
+                the nn.Parameter tensor.
+                Default: 3
+            batch (int, optional): The number of images to stack along the batch
+                dimension when creating the nn.Parameter tensor, or when stacking
+                init images.
+                Default: 1
+            parameterization (ImageParameterization, optional): An image
+                parameterization class.
+                Default: FFTImage
+            squash_func (Callable[[torch.Tensor], torch.Tensor], optional): The
+                squash function to use after color recorrelation. A function or
+                lambda function.
+                Default: None
+            decorrelation_module (nn.Module, optional): A ToRGB instance.
+                Default: ToRGB
+            decorrelate_init (bool, optional): Whether or not to apply color
+                decorrelation to the init tensor input.
+                Default: True
+        """
         super().__init__()
         self.decorrelate = decorrelation_module
         if init is not None:
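A short sketch of the image parameterizations documented above, using only the defaults named in these docstrings. The file path is a placeholder and the shape comments are expectations, not verified outputs:

    from captum.optim._param.image.images import FFTImage, ImageTensor, NaturalImage

    # ImageTensor: load, display, and save an image.
    img = ImageTensor.open("sample.png", scale=255.0, mode="RGB")
    img.show(figsize=(4, 4))
    img.export("sample_copy.png")

    # FFTImage: an inverse real 2D FFT parameterization.
    fft_param = FFTImage(size=(224, 224), channels=3, batch=1)
    out = fft_param()  # expected: a named NCHW tensor, e.g. [1, 3, 224, 224]

    # NaturalImage: FFTImage plus KLT color decorrelation per the documented defaults.
    image_param = NaturalImage(size=(224, 224), channels=3, batch=1)
    image = image_param()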
diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py
index ba3c146e40..93df78243e 100644
--- a/captum/optim/_param/image/transforms.py
+++ b/captum/optim/_param/image/transforms.py
@@ -13,15 +13,31 @@
 class BlendAlpha(nn.Module):
     r"""Blends a 4 channel input parameterization into an RGB image.
-    You can specify a fixed background, or a random one will be used by default.
     """
 
     def __init__(self, background: Optional[torch.Tensor] = None) -> None:
+        """
+        Args:
+
+            background (tensor, optional): An NCHW image tensor to be used as the
+                Alpha channel's background.
+                Default: None
+        """
         super().__init__()
         self.background = background
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Blend the Alpha channel into the RGB channels.
+
+        Args:
+
+            x (torch.Tensor): RGBA image tensor to blend into an RGB image tensor.
+
+        Returns:
+            **blended** (torch.Tensor): RGB image tensor.
+        """
         assert x.dim() == 4
         assert x.size(1) == 4
         rgb, alpha = x[:, :3, ...], x[:, 3:4, ...]
@@ -36,6 +52,16 @@ class IgnoreAlpha(nn.Module):
     r"""Ignores a 4th channel"""
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Ignore the alpha channel.
+
+        Args:
+
+            x (torch.Tensor): RGBA image tensor.
+
+        Returns:
+            **rgb** (torch.Tensor): RGB image tensor without the alpha channel.
+        """
         assert x.dim() == 4
         assert x.size(1) == 4
         rgb = x[:, :3, ...]
@@ -52,16 +78,17 @@ class ToRGB(nn.Module):
     [0] Y. Ohta, T. Kanade, and T. Sakai, "Color information for region
     segmentation," Computer Graphics and Image Processing, vol. 13, no. 3, pp.
     222–241, 1980
     https://www.sciencedirect.com/science/article/pii/0146664X80900477
-
-    Arguments:
-        transform (str or tensor): Either a string for one of the precalculated
-            transform matrices, or a 3x3 matrix for the 3 RGB channels of input
-            tensors.
     """
 
     @staticmethod
     def klt_transform() -> torch.Tensor:
-        """Karhunen-Loève transform (KLT) measured on ImageNet"""
+        """
+        Karhunen-Loève transform (KLT) measured on ImageNet
+
+        Returns:
+            **transform** (torch.Tensor): A Karhunen-Loève transform (KLT)
+                measured on the ImageNet dataset.
+        """
         KLT = [[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]]
         transform = torch.Tensor(KLT).float()
         transform = transform / torch.max(torch.norm(transform, dim=0))
@@ -69,6 +96,11 @@ def klt_transform() -> torch.Tensor:
 
     @staticmethod
     def i1i2i3_transform() -> torch.Tensor:
+        """
+        Returns:
+            **transform** (torch.Tensor): An approximation of the natural colors
+                transform (i1i2i3).
+        """
         i1i2i3_matrix = [
             [1 / 3, 1 / 3, 1 / 3],
             [1 / 2, 0, -1 / 2],
@@ -77,6 +109,13 @@ def i1i2i3_transform() -> torch.Tensor:
         return torch.Tensor(i1i2i3_matrix)
 
     def __init__(self, transform: Union[str, torch.Tensor] = "klt") -> None:
+        """
+        Args:
+
+            transform (str or tensor): Either a string for one of the precalculated
+                transform matrices, or a 3x3 matrix for the 3 RGB channels of input
+                tensors.
+        """
         super().__init__()
         assert isinstance(transform, str) or torch.is_tensor(transform)
         if torch.is_tensor(transform):
@@ -93,6 +132,18 @@ def __init__(self, transform: Union[str, torch.Tensor] = "klt") -> None:
         )
 
     def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:
+        """
+        Args:
+
+            x (torch.tensor): A CHW or NCHW RGB or RGBA image tensor.
+            inverse (bool, optional): Whether to recorrelate or decorrelate colors.
+                Default: False
+
+        Returns:
+            chw (torch.tensor): A tensor with its colors recorrelated or
+                decorrelated.
+        """
+
         assert x.dim() == 3 or x.dim() == 4
 
         # alpha channel is taken off...
@@ -128,15 +179,6 @@ def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:
 class CenterCrop(torch.nn.Module):
     """
     Center crop a specified amount from a tensor.
-    Arguments:
-        size (int, sequence, int): Number of pixels to center crop away.
-        pixels_from_edges (bool, optional): Whether to treat crop size
-            values as the number of pixels from the tensor's edge, or an
-            exact shape in the center.
-        offset_left (bool, optional): If the cropped away sides are not
-            equal in size, offset center by +1 to the left and/or top.
-            Default is set to False. This parameter is only valid when
-            pixels_from_edges is False.
     """
 
     def __init__(
@@ -145,6 +187,18 @@ def __init__(
         pixels_from_edges: bool = False,
         offset_left: bool = False,
     ) -> None:
+        """
+        Args:
+
+            size (int or sequence of int): Number of pixels to center crop away.
+            pixels_from_edges (bool, optional): Whether to treat crop size
+                values as the number of pixels from the tensor's edge, or an
+                exact shape in the center.
+            offset_left (bool, optional): If the cropped away sides are not
+                equal in size, offset center by +1 to the left and/or top.
+                This parameter is only valid when `pixels_from_edges` is False.
+                Default: False
+        """
         super().__init__()
         self.crop_vals = size
         self.pixels_from_edges = pixels_from_edges
@@ -153,10 +207,12 @@ def __init__(
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         """
         Center crop an input.
-        Arguments:
+
+        Args:
             input (torch.Tensor): Input to center crop.
+
         Returns:
-            tensor (torch.Tensor): A center cropped tensor.
+            **tensor** (torch.Tensor): A center cropped *tensor*.
         """
 
         return center_crop(
@@ -172,18 +228,22 @@ def center_crop(
 ) -> torch.Tensor:
     """
     Center crop a specified amount from a tensor.
-    Arguments:
+
+    Args:
+
+        input (tensor): A CHW or NCHW image tensor to center crop.
         size (int, sequence, int): Number of pixels to center crop away.
         pixels_from_edges (bool, optional): Whether to treat crop size
             values as the number of pixels from the tensor's edge, or an
             exact shape in the center.
+            Default: False
         offset_left (bool, optional): If the cropped away sides are not
             equal in size, offset center by +1 to the left and/or top.
-            Default is set to False. This parameter is only valid when
-            pixels_from_edges is False.
+            This parameter is only valid when `pixels_from_edges` is False.
+            Default: False
+
     Returns:
-        *tensor*: A center cropped tensor.
+        **tensor**: A center cropped *tensor*.
     """
 
     assert input.dim() == 3 or input.dim() == 4
@@ -218,6 +278,13 @@ def _rand_select(
 ) -> Union[int, float, torch.Tensor]:
     """
     Randomly return a single value from the provided tuple, list, or tensor.
+
+    Args:
+
+        transform_values (sequence): A sequence of values to randomly select from.
+
+    Returns:
+        **value**: A single value from the specified sequence.
     """
     n = torch.randint(low=0, high=len(transform_values), size=[1]).item()
     return transform_values[n]
@@ -226,11 +293,14 @@ class RandomScale(nn.Module):
     """
     Apply random rescaling on a NCHW tensor.
-    Arguments:
-        scale (float, sequence): Tuple of rescaling values to randomly select from.
     """
 
     def __init__(self, scale: NumSeqOrTensorType) -> None:
+        """
+        Args:
+
+            scale (float or sequence): A sequence of rescaling values to randomly
+                select from.
+        """
         super().__init__()
         self.scale = scale
 
     def scale_tensor(
         self, x: torch.Tensor, scale: float
     ) -> torch.Tensor:
         return x
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """
+        Randomly scale / zoom in or out of a tensor.
+
+        Args:
+
+            input (torch.Tensor): Input to randomly scale.
+
+        Returns:
+            **tensor** (torch.Tensor): Scaled *tensor*.
+        """
         scale = _rand_select(self.scale)
         return self.scale_tensor(input, scale=scale)
@@ -265,11 +345,14 @@ class RandomSpatialJitter(torch.nn.Module):
     """
     Apply random spatial translations on a NCHW tensor.
-    Arguments:
-        translate (int):
     """
 
     def __init__(self, translate: int) -> None:
+        """
+        Args:
+
+            translate (int): The maximum horizontal and vertical translation to use.
+        """
         super().__init__()
         self.pad_range = 2 * translate
         self.pad = nn.ReflectionPad2d(translate)
@@ -287,6 +370,16 @@ def translate_tensor(self, x: torch.Tensor, insets: torch.Tensor) -> torch.Tensor:
         return cropped
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """
+        Randomly translate an input tensor's height and width dimensions.
+
+        Args:
+
+            input (torch.Tensor): Input to randomly translate.
+
+        Returns:
+            **tensor** (torch.Tensor): A randomly translated *tensor*.
+        """
         insets = torch.randint(high=self.pad_range, size=(2,))
         return self.translate_tensor(input, insets)
+ """ return x * self.multiplier @@ -311,6 +419,16 @@ class RGBToBGR(nn.Module): """ def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Perform RGB to BGR conversion on an input + + Args: + + x (torch.Tensor): RGB image tensor to convert to BGR. + + Returns: + **BGR tensor** (torch.Tensor): A BGR tensor. + """ assert x.dim() == 4 assert x.size(1) == 3 return x[:, [2, 1, 0]] @@ -354,13 +472,6 @@ class GaussianSmoothing(nn.Module): Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input using a depthwise convolution. - Arguments: - channels (int, sequence): Number of channels of the input tensors. Output will - have this number of channels as well. - kernel_size (int, sequence): Size of the gaussian kernel. - sigma (float, sequence): Standard deviation of the gaussian kernel. - dim (int, optional): The number of dimensions of the data. - Default value is 2 (spatial). """ def __init__( @@ -370,6 +481,16 @@ def __init__( sigma: Union[float, Sequence[float]], dim: int = 2, ) -> None: + """ + Args: + + channels (int, sequence): Number of channels of the input tensors. Output + will have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ super().__init__() if isinstance(kernel_size, numbers.Number): kernel_size = [kernel_size] * dim @@ -414,10 +535,13 @@ def __init__( def forward(self, input: torch.Tensor) -> torch.Tensor: """ Apply gaussian filter to input. - Arguments: + + Args: + input (torch.Tensor): Input to apply gaussian filter on. + Returns: - filtered (torch.Tensor): Filtered output. + **filtered** (torch.Tensor): Filtered output. """ return self.conv(input, weight=self.weight, groups=self.groups) @@ -431,6 +555,16 @@ class SymmetricPadding(torch.autograd.Function): def forward( ctx: torch.autograd.Function, x: torch.Tensor, padding: List[List[int]] ) -> torch.Tensor: + """ + Apply NumPy symmetric padding to an input tensor while preserving the gradient. + + Args: + + x (torch.Tensor): Input to apply symmetric padding on. + + Returns: + **tensor** (torch.Tensor): Padded tensor. + """ ctx.padding = padding x_device = x.device x = x.cpu() @@ -444,6 +578,16 @@ def forward( def backward( ctx: torch.autograd.Function, grad_output: torch.Tensor ) -> Tuple[torch.Tensor, None]: + """ + Crop away symmetric padding. + + Args: + + grad_output (torch.Tensor): Input to remove symmetric padding from. + + Returns: + **grad_input** (torch.Tensor): Unpadded tensor. + """ grad_input = grad_output.clone() B, C, H, W = grad_input.size() b1, b2 = ctx.padding[0] @@ -460,26 +604,44 @@ class NChannelsToRGB(nn.Module): """ def __init__(self, warp: bool = False) -> None: + """ + Args: + + warp (bool, optional): Whether or not to make the resulting RGB colors more + distict from each other. Default is set to False. + """ super().__init__() self.warp = warp def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Reduce any number of channels down to 3. + + Args: + + x (torch.Tensor): Input to reduce channel dimensions on. + + Returns: + **3 channel RGB tensor** (torch.Tensor): RGB image tensor. + """ assert x.dim() == 4 return nchannels_to_rgb(x, self.warp) class RandomCrop(nn.Module): """ - Randomly crop out a specific size from an NCHW image tensor. - ​ - Args: - crop_size (int, sequence, int): The desired cropped output size. 
diff --git a/captum/optim/_utils/circuits.py b/captum/optim/_utils/circuits.py
index dfe97b204e..3dc2f3e524 100644
--- a/captum/optim/_utils/circuits.py
+++ b/captum/optim/_utils/circuits.py
@@ -30,8 +30,8 @@ def extract_expanded_weights(
             specified for target2.
         target2 (nn.Module): The end target layer. Must be above the layer
             specified for target1.
-        crop_shape (int or tuple of ints, optional): Specify the output weight
-            size to enter crop away padding.
+        crop_shape (int or tuple of ints, optional): Specify the exact output size
+            to crop out.
         model_input (tensor or tuple of tensors, optional): The input to use
             with the specified model.
         crop_func (Callable, optional): Specify a function to crop away the padding
diff --git a/captum/optim/_utils/image/dataset.py b/captum/optim/_utils/image/dataset.py
index 69a2be3453..fcc6d03742 100644
--- a/captum/optim/_utils/image/dataset.py
+++ b/captum/optim/_utils/image/dataset.py
@@ -3,7 +3,12 @@
 def image_cov(tensor: torch.Tensor) -> torch.Tensor:
     """
-    Calculate a tensor's RGB covariance matrix
+    Calculate a tensor's RGB covariance matrix.
+
+    Args:
+        tensor (tensor): An NCHW image tensor.
+    Returns:
+        *tensor*: An RGB covariance matrix for the specified tensor.
     """
 
     tensor = tensor.reshape(-1, 3)
@@ -14,6 +19,12 @@ def dataset_cov_matrix(loader: torch.utils.data.DataLoader) -> torch.Tensor:
     """
     Calculate the covariance matrix for an image dataset.
+
+    Args:
+        loader (torch.utils.data.DataLoader): A reference to a PyTorch
+            dataloader instance.
+    Returns:
+        *tensor*: A covariance matrix for the specified dataset.
     """
 
     cov_mtx = torch.zeros(3, 3)
@@ -30,6 +41,13 @@ def cov_matrix_to_klt(
 ) -> torch.Tensor:
     """
     Convert a cov matrix to a klt matrix.
+
+    Args:
+        cov_mtx (tensor): A 3 by 3 covariance matrix generated from a dataset.
+        normalize (bool): Whether or not to normalize the resulting KLT matrix.
+        epsilon (float): A small epsilon value used to avoid numerical issues.
+    Returns:
+        *tensor*: A KLT matrix for the specified covariance matrix.
     """
 
     U, S, V = torch.svd(cov_mtx)
@@ -47,6 +65,13 @@ def dataset_klt_matrix(
     a Karhunen-Loève transform (KLT) matrix, for a dataset. The color correlation
     matrix can then used in color decorrelation transforms for models trained on
     the dataset.
+
+    Args:
+        loader (torch.utils.data.DataLoader): A reference to a PyTorch
+            dataloader instance.
+        normalize (bool): Whether or not to normalize the resulting KLT matrix.
+    Returns:
+        *tensor*: A KLT matrix for the specified dataset.
     """
 
     cov_mtx = dataset_cov_matrix(loader)
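A sketch of building a dataset color matrix with the helpers documented above, reusing the ImageTestDataset helper defined later in this patch; the tensor shapes and batch size are illustrative assumptions:

    import torch
    from captum.optim._utils.image.dataset import dataset_klt_matrix
    from tests.optim.helpers.image_dataset import ImageTestDataset

    dataset = ImageTestDataset([torch.randn(1, 3, 64, 64) for _ in range(4)])
    loader = torch.utils.data.DataLoader(dataset, batch_size=2)

    # Per the docstring above, the result is a KLT matrix usable for color
    # decorrelation transforms (e.g. ToRGB) on models trained on this dataset.
    klt_matrix = dataset_klt_matrix(loader, normalize=True)  # expected: 3 x 3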
-    See: https://distill.pub/2018/building-blocks/
+
+    See here for more information: https://distill.pub/2018/building-blocks/
+
+    Args:
+        n_components (int, optional): The number of channels to reduce the target
+            dimension to.
+        reduction_alg (str or callable, optional): The desired dimensionality
+            reduction algorithm to use. The default reduction_alg is set to NMF from
+            sklearn, which requires users to put inputs on CPU before passing them
+            to fit_transform.
+        **kwargs (optional): Arbitrary keyword arguments used by the specified
+            reduction_alg.
     """
 
     def __init__(
@@ -63,9 +71,13 @@ def fit_transform(
     ) -> torch.Tensor:
         """
         Perform dimensionality reduction on an input tensor.
-
-        If swap_2nd_and_last_dims is true, input channels are expected to be in the
-        second dimension unless the input tensor has a shape of CHW.
+
+        Args:
+            x (tensor): A tensor to perform dimensionality reduction on.
+            swap_2nd_and_last_dims (bool, optional): If true, input channels are
+                expected to be in the second dimension unless the input tensor has
+                a shape of CHW.
+                Default: True
+        Returns:
+            *tensor*: A tensor with one of its dimensions reduced.
         """
 
         if x.dim() == 3 and swap_2nd_and_last_dims:
@@ -115,8 +127,15 @@ def __dir__(self) -> List:
 def posneg(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
     """
-    Hack that makes a matrix positive by concatination in order to simulate
-    one-sided NMF with regular NMF
+    Hack that makes a matrix positive by concatenation in order to simulate
+    one-sided NMF with regular NMF.
+
+    Args:
+        x (tensor): A tensor to make positive.
+        dim (int, optional): The dimension to concatenate the two tensor halves at.
+    Returns:
+        tensor (torch.tensor): A positive tensor for one-sided dimensionality
+            reduction.
     """
 
     return torch.cat([F.relu(x), F.relu(-x)], dim=dim)
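A sketch of the ChannelReducer workflow described above: NMF, the documented default, needs non-negative CPU inputs, so posneg() is applied along the channel dimension first. The activation shape is illustrative:

    import torch
    from captum.optim._utils.reducer import ChannelReducer, posneg

    activations = torch.randn(1, 512, 14, 14)  # NCHW activations, kept on CPU
    reducer = ChannelReducer(n_components=3, reduction_alg="NMF")

    # posneg doubles the channel count (ReLU of x and of -x); fit_transform then
    # reduces the channel dimension down to n_components.
    reduced = reducer.fit_transform(posneg(activations, dim=1))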
+
+    Args:
+        array (array): A 2d array to create the heatmap from.
+        colors (list of str): A list of strings containing color
+            hex values to use for coloring the heatmap.
+    Returns:
+        *array*: A weight heatmap.
     """
 
     assert array.ndim == 2
diff --git a/tests/optim/helpers/numpy_transforms.py b/tests/optim/helpers/numpy_transforms.py
index 386478f236..eec0afebac 100644
--- a/tests/optim/helpers/numpy_transforms.py
+++ b/tests/optim/helpers/numpy_transforms.py
@@ -8,7 +8,11 @@
 class BlendAlpha:
     """
-    NumPy version of the BlendAlpha transform
+    NumPy version of the BlendAlpha transform.
+
+    Args:
+        background (array, optional): An NCHW image array to be used as the
+            Alpha channel's background.
     """
 
     def __init__(self, background: Optional[np.ndarray] = None) -> None:
@@ -16,6 +20,14 @@ def __init__(self, background: Optional[np.ndarray] = None) -> None:
         self.background = background
 
     def blend_alpha(self, x: np.ndarray) -> np.ndarray:
+        """
+        Blend the Alpha channel into the RGB channels.
+
+        Args:
+            x (array): RGBA image array to blend into an RGB image array.
+        Returns:
+            blended (array): RGB image array.
+        """
         assert x.shape[1] == 4
         assert x.ndim == 4
         rgb, alpha = x[:, :3, ...], x[:, 3:4, ...]
@@ -30,7 +42,11 @@ def blend_alpha(self, x: np.ndarray) -> np.ndarray:
 class RandomSpatialJitter:
     """
-    NumPy version of the RandomSpatialJitter transform
+    NumPy version of the RandomSpatialJitter transform.
+
+    Args:
+        translate (int): The amount to translate the H and W dimensions
+            of a CHW or NCHW array.
     """
 
     def __init__(self, translate: int) -> None:
@@ -59,6 +75,7 @@ def jitter(self, x: np.ndarray) -> np.ndarray:
 class CenterCrop:
     """
     Center crop a specified amount from a tensor.
+
     Arguments:
         size (int or sequence of int): Number of pixels to center crop away.
         pixels_from_edges (bool, optional): Whether to treat crop size values
@@ -181,6 +198,15 @@ def __init__(self, transform: Union[str, np.ndarray] = "klt") -> None:
         )
 
     def to_rgb(self, x: np.ndarray, inverse: bool = False) -> np.ndarray:
+        """
+        Args:
+            x (array): A CHW or NCHW RGB or RGBA image array.
+            inverse (bool, optional): Whether to recorrelate or decorrelate colors.
+                Default: False
+        Returns:
+            *array*: An array with its colors recorrelated or decorrelated.
+        """
+
         assert x.ndim == 3 or x.ndim == 4
 
         # alpha channel is taken off...
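These helpers mirror the PyTorch transforms in NumPy so the tests can compare the two. A hypothetical parity check in that spirit, assuming both implementations share the same blend formula (a fixed background is passed because the default background is randomized):

    import numpy as np
    import torch
    from captum.optim._param.image.transforms import BlendAlpha as TorchBlendAlpha
    from tests.optim.helpers.numpy_transforms import BlendAlpha as NumpyBlendAlpha

    x = np.random.randn(1, 4, 8, 8).astype(np.float32)
    background = np.zeros((1, 3, 8, 8), dtype=np.float32)

    np_out = NumpyBlendAlpha(background=background).blend_alpha(x)
    torch_out = TorchBlendAlpha(background=torch.from_numpy(background))(
        torch.from_numpy(x)
    )
    assert np.allclose(np_out, torch_out.numpy(), atol=1e-5)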
diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py
index 525d6277aa..7c420aa579 100644
--- a/tests/optim/param/test_images.py
+++ b/tests/optim/param/test_images.py
@@ -337,7 +337,7 @@ def test_sharedimage_get_offset_single_number(self) -> None:
             shapes=shared_shapes, parameterization=test_param
         )
 
-        offset = image_param.get_offset(4, 3)
+        offset = image_param._get_offset(4, 3)
         self.assertEqual(len(offset), 3)
         self.assertEqual(offset, [[4, 4, 4, 4]] * 3)
 
@@ -354,7 +354,7 @@ def test_sharedimage_get_offset_exact(self) -> None:
         )
 
         offset_vals = ((1, 2, 3, 4), (4, 3, 2, 1), (1, 2, 3, 4))
-        offset = image_param.get_offset(offset_vals, 3)
+        offset = image_param._get_offset(offset_vals, 3)
         self.assertEqual(len(offset), 3)
         self.assertEqual(offset, [[int(o) for o in v] for v in offset_vals])
 
@@ -371,7 +371,7 @@ def test_sharedimage_get_offset_single_set_four_numbers(self) -> None:
         )
 
         offset_vals = (1, 2, 3, 4)
-        offset = image_param.get_offset(offset_vals, 3)
+        offset = image_param._get_offset(offset_vals, 3)
         self.assertEqual(len(offset), 3)
         self.assertEqual(offset, [list(offset_vals)] * 3)
 
@@ -388,7 +388,7 @@ def test_sharedimage_get_offset_single_set_three_numbers(self) -> None:
         )
 
         offset_vals = (2, 3, 4)
-        offset = image_param.get_offset(offset_vals, 3)
+        offset = image_param._get_offset(offset_vals, 3)
         self.assertEqual(len(offset), 3)
         self.assertEqual(offset, [[0] + list(offset_vals)] * 3)
 
@@ -405,7 +405,7 @@ def test_sharedimage_get_offset_single_set_two_numbers(self) -> None:
         )
 
         offset_vals = (3, 4)
-        offset = image_param.get_offset(offset_vals, 3)
+        offset = image_param._get_offset(offset_vals, 3)
         self.assertEqual(len(offset), 3)
         self.assertEqual(offset, [[0, 0] + list(offset_vals)] * 3)
 
@@ -448,7 +448,7 @@ def test_apply_offset(self):
         )
         test_x_list = [torch.ones(*size) for x in range(size[0])]
 
-        output_A = image_param.apply_offset(test_x_list)
+        output_A = image_param._apply_offset(test_x_list)
 
         x_list = [torch.ones(*size) for x in range(size[0])]
         self.assertEqual(image_param.offset, [list(offset_vals)])
@@ -475,7 +475,7 @@ def test_interpolate_tensor(self) -> None:
         batch = 1
 
         test_tensor = torch.ones(6, 4, 128, 128)
-        output_tensor = image_param.interpolate_tensor(
+        output_tensor = image_param._interpolate_tensor(
             test_tensor, batch, channels, size[0], size[1]
         )
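The renamed private helpers keep the behavior these tests pin down. A small sketch of the `_get_offset` semantics exercised above; the SharedImage construction mirrors the test setup and is otherwise an assumption:

    from captum.optim._param.image.images import FFTImage, SharedImage

    image_param = SharedImage(
        shapes=((1, 3, 224, 224),), parameterization=FFTImage(size=(224, 224))
    )

    # A single int is broadcast to one four-value offset per tensor; shorter
    # sequences are left-padded with zeros (values taken from the tests above).
    print(image_param._get_offset(4, 3))       # [[4, 4, 4, 4]] * 3
    print(image_param._get_offset((3, 4), 3))  # [[0, 0, 3, 4]] * 3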