196 changes: 164 additions & 32 deletions captum/optim/_core/loss.py
@@ -16,6 +16,11 @@ def _make_arg_str(arg: Any) -> str:
return arg[:15] + "..." if too_big else arg


# Reduction op used by CompositeLoss to avoid size mismatches when composing
# losses. REDUCTION_OP is only applied in binary math operations between two
# Loss instances.
REDUCTION_OP: Callable[[torch.Tensor], torch.Tensor] = torch.mean
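
# A hedged sketch of swapping this default reduction, assuming two existing
# Loss instances `loss_a` and `loss_b` (hypothetical names). REDUCTION_OP is
# read at call time, so reassigning the module global changes how composed
# losses are reduced:
#
#     import captum.optim._core.loss as opt_loss
#     opt_loss.REDUCTION_OP = torch.sum  # reduce each operand with sum
#     combined = loss_a + loss_b  # math_op(torch.sum(...), torch.sum(...))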


class Loss(ABC):
"""
Abstract Class to describe loss.
@@ -40,6 +45,12 @@ def __repr__(self) -> str:
def __neg__(self) -> "CompositeLoss":
return module_op(self, None, operator.neg)

def __pos__(self) -> "CompositeLoss":
return module_op(self, None, operator.pos)

def __abs__(self) -> "CompositeLoss":
return module_op(self, None, operator.abs)

def __add__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
return module_op(self, other, operator.add)

Expand All @@ -52,6 +63,9 @@ def __mul__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
def __truediv__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
return module_op(self, other, operator.truediv)

def __floordiv__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
return module_op(self, other, operator.floordiv)

def __pow__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
return module_op(self, other, operator.pow)
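
    # Hedged usage sketch: because each operator above returns a CompositeLoss,
    # loss objectives compose like ordinary numbers (`loss_a` and `loss_b` are
    # hypothetical Loss instances):
    #
    #     objective = (loss_a + 0.5 * loss_b) ** 2
    #     objective = abs(-loss_a) // 2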

@@ -65,40 +79,58 @@ def __rmul__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
return self.__mul__(other)

    def __rtruediv__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
        return rmodule_op(self, other, operator.truediv)

    def __rfloordiv__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
        return rmodule_op(self, other, operator.floordiv)

    def __rpow__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
        return rmodule_op(self, other, operator.pow)

    def mean(
        self, dim: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False
    ) -> "CompositeLoss":
        """
        Composable torch.mean reduction operator. See torch.mean for more details:
        https://pytorch.org/docs/stable/generated/torch.mean.html

        Args:

            dim (int or tuple of int, optional): The dimension or dimensions to
                reduce.
                Default: None (reduce over all dimensions)
            keepdim (bool, optional): Whether the output tensor has dim retained
                or not.
                Default: False

        Returns:
            composite_loss (CompositeLoss): A composable loss instance.
        """
        if dim is None:
            return custom_composable_op(self, torch.mean)
        else:
            return custom_composable_op(self, torch.mean, dim=dim, keepdim=keepdim)

    def sum(
        self, dim: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False
    ) -> "CompositeLoss":
        """
        Composable torch.sum reduction operator. See torch.sum for more details:
        https://pytorch.org/docs/stable/generated/torch.sum.html

        Args:

            dim (int or tuple of int, optional): The dimension or dimensions to
                reduce.
                Default: None (reduce over all dimensions)
            keepdim (bool, optional): Whether the output tensor has dim retained
                or not.
                Default: False

        Returns:
            composite_loss (CompositeLoss): A composable loss instance.
        """
        if dim is None:
            return custom_composable_op(self, torch.sum)
        else:
            return custom_composable_op(self, torch.sum, dim=dim, keepdim=keepdim)
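
    # Hedged usage sketch for the reduction helpers above, assuming `loss` is a
    # Loss instance whose target outputs a 4D NCHW activation tensor:
    #
    #     scalar_obj = loss.mean()                   # reduce over all dimensions
    #     spatial_obj = loss.mean(dim=(2, 3))        # keep batch & channel dims
    #     batch_obj = loss.sum(dim=0, keepdim=True)  # sum across the batch dim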


def module_op(
@@ -107,7 +139,7 @@ def module_op(
"""
This is a general function for applying math operations to Losses
"""
if other is None and math_op in [operator.neg, operator.pos, operator.abs]:

def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
return math_op(self(module))
@@ -124,7 +156,7 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
elif isinstance(other, Loss):
        # We reduce each output tensor with REDUCTION_OP (torch.mean by
        # default) to resolve shape mismatches
def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
return math_op(REDUCTION_OP(self(module)), REDUCTION_OP(other(module)))

name = f"Compose({', '.join([self.__name__, other.__name__])})"
target = (
@@ -138,18 +170,110 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
return CompositeLoss(loss_fn, name=name, target=target)


def rmodule_op(
self: Loss, other: Union[int, float, "Loss"], math_op: Callable
) -> "CompositeLoss":
"""
This is a general function for applying the "r" versions of math operations to
Losses.
"""
if isinstance(other, (int, float)):

def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
return math_op(other, self(module))

name = self.__name__
target = self.target
elif isinstance(other, Loss):
# This should never get called because __math_op__ will be called instead
pass
else:
raise TypeError(
"Can only apply math operations with int, float or Loss. Received type "
+ str(type(other))
)
return CompositeLoss(loss_fn, name=name, target=target)
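
# Hedged sketch of the reflected ("r") operators that route through rmodule_op,
# with `loss` standing in for any existing Loss instance:
#
#     inverted = 1.0 / loss  # __rtruediv__ -> operator.truediv(1.0, loss(module))
#     floored = 7 // loss    # __rfloordiv__ -> operator.floordiv(7, loss(module))
#     exp_obj = 2 ** loss    # __rpow__ -> operator.pow(2, loss(module))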


def custom_composable_op(
    loss: Union[Loss, List[Loss]],
    loss_op_fn: Callable,
    *args: Any,
    **kwargs: Any,
) -> "CompositeLoss":
    """
    Implement composability for operations that take a single tensor or list of
    tensors and then return a single tensor. Custom user-defined functions can be
    used in addition to some built-in Python functions and PyTorch operations.

    Args:

        loss (Loss or list of Loss): A loss objective or list of loss objectives.
        loss_op_fn (Callable): A supported PyTorch, Python, or custom function.
        args (Any, optional): Any additional positional arguments to pass to
            loss_op_fn.
        kwargs (Any, optional): Any additional keyword arguments to pass to
            loss_op_fn.
        to_scalar_fn (Callable, optional): A function for converting loss function
            outputs to scalar values, in order to prevent size mismatches. This
            variable is only used if more than one loss is given.
            Default: None

    Returns:
        composite_loss (CompositeLoss): A composable loss instance.
    """

    if isinstance(loss, (tuple, list)):
        to_scalar_fn = kwargs.pop("to_scalar_fn", None)

def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
loss_tensors = [loss_obj(module) for loss_obj in loss]
if to_scalar_fn is not None:
loss_tensors = [to_scalar_fn(tensor) for tensor in loss_tensors]
return loss_op_fn(loss_tensors, *args, **kwargs)

name_list = ", ".join([loss_obj.__name__ for loss_obj in loss])
name = loss_op_fn.__name__ + "(" + name_list + ")"

# Collect targets from losses
target = [
target
for targets in [
[loss_obj.target]
if not hasattr(loss_obj.target, "__iter__")
else loss_obj.target
for loss_obj in loss
]
for target in targets
]
else:

def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
return loss_op_fn(loss(module), *args, **kwargs)

name = loss_op_fn.__name__ + "(" + loss.__name__ + ")"
target = loss.target
return CompositeLoss(loss_fn, name=name, target=target)
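
# Hedged sketch of combining several losses with custom_composable_op, assuming
# existing Loss instances `loss_a` and `loss_b`. Each output is first reduced to
# a 0-dim tensor by to_scalar_fn, so torch.stack receives same-shaped inputs and
# returns a 1-D tensor holding both scalarized losses:
#
#     combined = custom_composable_op(
#         [loss_a, loss_b], torch.stack, to_scalar_fn=torch.mean
#     )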


class BaseLoss(Loss):
def __init__(
self,
target: Union[nn.Module, List[nn.Module]] = [],
[Inline review thread on the `target` parameter above]

aobo-y (Contributor), Apr 5, 2022:

    If the target can be List[nn.Module], many losses below cannot directly use
    it as a dict key in targets_to_values[self.target]. Did I miss anything?

Author (Contributor):

    @aobo-y Losses like ActivationInterpolation have multiple targets (Faceted
    loss as well in an upcoming PR), but BaseLoss works off of a single target
    variable.

    The BaseLoss class is called in the __init__ functions of loss classes like
    so:

        # Single target
        BaseLoss.__init__(self, target, batch_index)

        # Multiple targets
        BaseLoss.__init__(self, [target1, target2])

    The loss class itself will indicate via the target: List[nn.Module] type
    hint that multiple targets are supported / required, or it handles things
    internally by passing the targets as a list to BaseLoss, as in
    ActivationInterpolation.

    The ActivationInterpolation loss class can be found here:
    https://github.com/ProGamerGov/captum/blob/optim-wip-composable-loss-improvements/captum/optim/_core/loss.py#L506

aobo-y (Contributor):

    Sure, but cases like DeepDream and some others directly inherit BaseLoss's
    __init__ definition, where target can be a list while it actually should not:
    https://github.com/ProGamerGov/captum/blob/optim-wip-composable-loss-improvements/captum/optim/_core/loss.py#L393-L407

    If these losses have different assumptions about what their targets should
    be, why do we abstract the target into the base class? The base class
    BaseLoss does not need target anyway. Each class can define its own target
    in __init__. Or we can have two other intermediate abstract classes,
    SingleTargetLoss and MultiTargetsLoss.

    But anyway, this is just for discussion; it has nothing to do with this PR.
    We can leave it to future updates if needed.

Author (Contributor):

    Oh, yeah I see what you mean now. In the original code, I think that Ludwig
    had SingleTargetObjective & MultiObjective for handling these cases:
    https://github.com/ludwigschubert/captum/blob/f1fd0729dece59564a7c10b7b397617d8a09a247/captum/optim/optim/objectives.py#L108

    It'd probably be best to leave this to a future PR if we decide on the
    changes.

[End of review thread]

batch_index: Optional[Union[int, List[int]]] = None,
) -> None:
super(BaseLoss, self).__init__()
self._target = target
if batch_index is None:
self._batch_index = (None, None)
elif isinstance(batch_index, (list, tuple)):
self._batch_index = tuple(batch_index)
else:
self._batch_index = (batch_index, batch_index + 1)
assert all([isinstance(b, (int, type(None))) for b in self._batch_index])
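
        # Hedged sketch of the accepted batch_index forms (used via BaseLoss
        # subclasses) and the slice bounds stored in self._batch_index:
        #
        #     batch_index=None    -> (None, None)  # use the whole batch
        #     batch_index=2       -> (2, 3)        # a single batch sample
        #     batch_index=[1, 4]  -> (1, 4)        # samples 1 through 3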

@property
def target(self) -> Union[nn.Module, List[nn.Module]]:
@@ -343,6 +467,7 @@ class Diversity(BaseLoss):

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
activations = activations[self.batch_index[0] : self.batch_index[1]]
batch, channels = activations.shape[:2]
flattened = activations.view(batch, channels, -1)
grams = torch.matmul(flattened, torch.transpose(flattened, 1, 2))
Expand Down Expand Up @@ -417,12 +542,18 @@ class Alignment(BaseLoss):
https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
"""

def __init__(
self,
target: nn.Module,
decay_ratio: float = 2.0,
batch_index: Optional[List[int]] = None,
) -> None:
BaseLoss.__init__(self, target, batch_index)
self.decay_ratio = decay_ratio

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
activations = activations[self.batch_index[0] : self.batch_index[1]]
B = activations.size(0)

sum_tensor = torch.zeros(1, device=activations.device)
@@ -737,6 +868,7 @@ def default_loss_summarize(loss_value: torch.Tensor) -> torch.Tensor:

__all__ = [
"Loss",
"REDUCTION_OP",
"loss_wrapper",
"BaseLoss",
"LayerActivation",