diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py
index aa33793642..ecb82fa72f 100644
--- a/captum/optim/_core/loss.py
+++ b/captum/optim/_core/loss.py
@@ -189,6 +189,14 @@ def wrapper(*args, **kwargs) -> object:
 class LayerActivation(BaseLoss):
     """
     Maximize activations at the target layer.
+    This is the most basic loss available; it simply returns the activations in
+    their original form.
+
+    Args:
+        target (nn.Module): The layer to optimize for.
+        batch_index (int, optional): The index of the image to optimize if we
+            are optimizing a batch of images. If unspecified, defaults to all
+            images in the batch.
     """
 
     def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
@@ -201,6 +209,15 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
 class ChannelActivation(BaseLoss):
     """
     Maximize activations at the target layer and target channel.
+    This loss maximizes the activations of a target channel in a specified
+    target layer, and can be useful for determining what features excite the
+    channel.
+
+    Args:
+        target (nn.Module): The layer containing the channel to optimize for.
+        channel_index (int): The index of the channel to optimize for.
+        batch_index (int, optional): The index of the image to optimize if we
+            are optimizing a batch of images. If unspecified, defaults to all
+            images in the batch.
     """
 
     def __init__(
@@ -224,6 +241,26 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
 
 @loss_wrapper
 class NeuronActivation(BaseLoss):
+    """
+    This loss maximizes the activations of a target neuron in the specified
+    channel from the specified layer. It is useful for determining the type of
+    features that excite a neuron, and is therefore often used in circuits and
+    neuron-related research.
+
+    Args:
+        target (nn.Module): The layer containing the channel to optimize for.
+        channel_index (int): The index of the channel to optimize for.
+        x (int, optional): The x coordinate of the neuron to optimize for. If
+            unspecified, defaults to center, or one unit left of center for
+            even widths.
+        y (int, optional): The y coordinate of the neuron to optimize for. If
+            unspecified, defaults to center, or one unit above center for even
+            heights.
+        batch_index (int, optional): The index of the image to optimize if we
+            are optimizing a batch of images. If unspecified, defaults to all
+            images in the batch.
+    """
+
     def __init__(
         self,
         target: nn.Module,
@@ -258,6 +295,16 @@ class DeepDream(BaseLoss):
     """
     Maximize 'interestingness' at the target layer.
     Mordvintsev et al., 2015.
+    https://github.com/google/deepdream
+    This loss returns the squared layer activations. When combined with a
+    negative mean loss summarization, it will create the hallucinogenic visuals
+    commonly referred to as 'Deep Dream'.
+
+    Args:
+        target (nn.Module): The layer to optimize for.
+        batch_index (int, optional): The index of the image to optimize if we
+            are optimizing a batch of images. If unspecified, defaults to all
+            images in the batch.
     """
 
     def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
@@ -272,6 +319,15 @@ class TotalVariation(BaseLoss):
     """
     Total variation denoising penalty for activations.
     See Mahendran, V. 2014.
     Understanding Deep Image Representations by Inverting Them.
     https://arxiv.org/abs/1412.0035
+    This loss attempts to smooth / denoise the target by performing total
+    variation denoising. The target is most often the image that is being
+    optimized. This loss is often used to remove unwanted visual artifacts.
+
+    Args:
+        target (nn.Module): The layer to optimize for.
+        batch_index (int, optional): The index of the image to optimize if we
+            are optimizing a batch of images. If unspecified, defaults to all
+            images in the batch.
     """
 
     def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
@@ -286,6 +342,14 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
 class L1(BaseLoss):
     """
     L1 norm of the target layer, generally used as a penalty.
+
+    Args:
+        target (nn.Module): The layer to optimize for.
+        constant (float): Constant threshold to deduct from the activations.
+            Defaults to 0.
+        batch_index (int, optional): The index of the image to optimize if we
+            are optimizing a batch of images. If unspecified, defaults to all
+            images in the batch.
     """
 
     def __init__(
@@ -307,6 +371,15 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
 class L2(BaseLoss):
     """
     L2 norm of the target layer, generally used as a penalty.
+
+    Args:
+        target (nn.Module): The layer to optimize for.
+        constant (float): Constant threshold to deduct from the activations.
+            Defaults to 0.
+        epsilon (float): Small value added before taking the square root, for
+            numerical stability. Defaults to 1e-6.
+        batch_index (int, optional): The index of the image to optimize if we
+            are optimizing a batch of images. If unspecified, defaults to all
+            images in the batch.
     """
 
     def __init__(
@@ -334,6 +407,14 @@ class Diversity(BaseLoss):
     """
     Use a cosine similarity penalty to extract features from a polysemantic neuron.
     Olah, Mordvintsev & Schubert, 2017.
     https://distill.pub/2017/feature-visualization/#diversity
+    This loss helps break up polysemantic layers, channels, and neurons by
+    encouraging diversity across the images in the batch. It is meant to be
+    used alongside a main loss.
+
+    Args:
+        target (nn.Module): The layer to optimize for.
+        batch_index (int, optional): Unused here, since we are optimizing for
+            diversity across the batch.
     """
 
     def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
@@ -359,6 +440,16 @@ class ActivationInterpolation(BaseLoss):
     """
     Interpolate between two different layers & channels.
     Olah, Mordvintsev & Schubert, 2017.
     https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
+    This loss helps to mix or interpolate visualizations from two activations
+    (layer or channel) by optimizing a linear sum of the two activations.
+
+    Args:
+        target1 (nn.Module): The first layer to optimize for.
+        channel_index1 (int): Index of the channel in the first layer to
+            optimize. Defaults to all channels.
+        target2 (nn.Module): The second layer to optimize for.
+        channel_index2 (int): Index of the channel in the second layer to
+            optimize. Defaults to all channels.
     """
 
     def __init__(
@@ -410,6 +501,14 @@ class Alignment(BaseLoss):
     """
     similarity between them.
     Olah, Mordvintsev & Schubert, 2017.
     https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
+    When interpolating between activations, it may be desirable to keep image
+    landmarks in the same position for visual comparison. This loss helps by
+    minimizing the L2 distance between neighboring images in the batch.
+
+    Args:
+        target (nn.Module): The layer to optimize for.
+        decay_ratio (float): How much to decay the penalty as images move apart
+            in the batch. Defaults to 2.0.
     """
 
     def __init__(self, target: nn.Module, decay_ratio: float = 2.0) -> None:
@@ -438,6 +537,18 @@ class Direction(BaseLoss):
     """
     Visualize a general direction vector.
     Carter, et al., "Activation Atlas", Distill, 2019.
     https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
+    This loss helps to visualize a specific direction vector in a layer by
+    maximizing the alignment between the input vector and the layer's
+    activation vector. The dimensionality of the vector should correspond to
+    the number of channels in the layer.
+
+    Args:
+        target (nn.Module): The layer to optimize for.
+        vec (torch.Tensor): Vector representing the direction to align to.
+        cossim_pow (float, optional): The desired cosine similarity power to
+            use.
+        batch_index (int, optional): The index of the image to optimize if we
+            are optimizing a batch of images. If unspecified, defaults to all
+            images in the batch.
     """
 
     def __init__(
@@ -464,6 +575,23 @@ class NeuronDirection(BaseLoss):
     """
     Visualize a single (x, y) position for a direction vector.
     Carter, et al., "Activation Atlas", Distill, 2019.
     https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
+    Extends the Direction loss by focusing on visualizing a single neuron
+    within the kernel.
+
+    Args:
+        target (nn.Module): The layer to optimize for.
+        vec (torch.Tensor): Vector representing the direction to align to.
+        x (int, optional): The x coordinate of the neuron to optimize for. If
+            unspecified, defaults to center, or one unit left of center for
+            even widths.
+        y (int, optional): The y coordinate of the neuron to optimize for. If
+            unspecified, defaults to center, or one unit above center for even
+            heights.
+        channel_index (int): The index of the channel to optimize for.
+        cossim_pow (float, optional): The desired cosine similarity power to
+            use.
+        batch_index (int, optional): The index of the image to optimize if we
+            are optimizing a batch of images. If unspecified, defaults to all
+            images in the batch.
     """
 
     def __init__(
@@ -505,6 +633,15 @@ class TensorDirection(BaseLoss):
     """
     Visualize a tensor direction vector.
     Carter, et al., "Activation Atlas", Distill, 2019.
     https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
+    Extends the Direction loss by allowing batch-wise direction visualization.
+
+    Args:
+        target (nn.Module): The layer to optimize for.
+        vec (torch.Tensor): Vector representing the direction to align to.
+        cossim_pow (float, optional): The desired cosine similarity power to
+            use.
+        batch_index (int, optional): The index of the image to optimize if we
+            are optimizing a batch of images. If unspecified, defaults to all
+            images in the batch.
     """
 
     def __init__(
@@ -542,6 +679,23 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
 class ActivationWeights(BaseLoss):
     """
     Apply weights to channels, neurons, or spots in the target.
+    This loss applies the given weights to specific channels or neurons in a
+    given layer, via a weight vector.
+
+    Args:
+        target (nn.Module): The layer to optimize for.
+        weights (torch.Tensor): Weights to apply to the targets.
+        neuron (bool): Whether or not the target is a neuron. Defaults to
+            False.
+        x (int, optional): The x coordinate of the neuron to optimize for. If
+            unspecified, defaults to center, or one unit left of center for
+            even widths.
+        y (int, optional): The y coordinate of the neuron to optimize for. If
+            unspecified, defaults to center, or one unit above center for even
+            heights.
+        wx (int, optional): Width of the neuron window along the x-axis to
+            apply the weights to.
+        wy (int, optional): Height of the neuron window along the y-axis to
+            apply the weights to.
     """
 
     def __init__(
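
Note for reviewers: below is a minimal sketch of how one of these losses might be exercised in isolation, relying only on the __call__(targets_to_values) contract visible in this diff. The toy model, the hand-built activation mapping, and the 0.5 penalty weight are illustrative assumptions; in normal use captum's optimization loop builds the ModuleOutputMapping itself, and the constructor arguments mirror the docstrings above.

import torch
import torch.nn as nn

from captum.optim._core.loss import ChannelActivation, TotalVariation

# Toy model; any nn.Module submodule can serve as a loss target.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
target = model[0]

# Hand-build the ModuleOutputMapping with a forward hook. In normal use the
# optimization loop collects these activations for every loss target.
targets_to_values = {}
hook = target.register_forward_hook(
    lambda module, inp, out: targets_to_values.update({module: out})
)

image = torch.randn(1, 3, 64, 64, requires_grad=True)
model(image)
hook.remove()

# Maximize channel 0 of the conv layer while penalizing total variation.
# The mean reduction and the 0.5 weight are illustrative choices.
objective = (
    ChannelActivation(target, channel_index=0)(targets_to_values).mean()
    - 0.5 * TotalVariation(target)(targets_to_values).mean()
)
(-objective).backward()  # gradient ascent would then step image along image.grad

Swapping in NeuronActivation, Direction, or the other losses above follows the same pattern, with the extra constructor arguments documented in their respective docstrings.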