diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py
index 0016b87998..3dfc54c58d 100644
--- a/captum/optim/_core/loss.py
+++ b/captum/optim/_core/loss.py
@@ -16,6 +16,11 @@ def _make_arg_str(arg: Any) -> str:
     return arg[:15] + "..." if too_big else arg
 
 
+# Reduction op used by CompositeLoss to avoid size mismatches when composing losses.
+# REDUCTION_OP is only used for binary math operations between two Loss instances.
+REDUCTION_OP: Callable[[torch.Tensor], torch.Tensor] = torch.mean
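+#
+# A minimal usage sketch (illustration only, not part of this diff's changes):
+# since the composed loss_fn looks REDUCTION_OP up at call time, patching the
+# module attribute swaps the reduction used when composing two Loss instances:
+#
+#     import captum.optim._core.loss as opt_loss
+#     opt_loss.REDUCTION_OP = torch.sum  # sum-reduce instead of mean-reduce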
+
+
 class Loss(ABC):
     """
     Abstract Class to describe loss.
@@ -40,6 +45,12 @@ def __repr__(self) -> str:
     def __neg__(self) -> "CompositeLoss":
         return module_op(self, None, operator.neg)
 
+    def __pos__(self) -> "CompositeLoss":
+        return module_op(self, None, operator.pos)
+
+    def __abs__(self) -> "CompositeLoss":
+        return module_op(self, None, operator.abs)
+
     def __add__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
         return module_op(self, other, operator.add)
 
@@ -52,6 +63,9 @@ def __mul__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
     def __truediv__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
         return module_op(self, other, operator.truediv)
 
+    def __floordiv__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
+        return module_op(self, other, operator.floordiv)
+
     def __pow__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
         return module_op(self, other, operator.pow)
 
@@ -65,40 +79,58 @@ def __rmul__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
         return self.__mul__(other)
 
     def __rtruediv__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
-        if isinstance(other, (int, float)):
+        return rmodule_op(self, other, operator.truediv)
 
-            def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
-                return operator.truediv(other, torch.mean(self(module)))
+    def __rfloordiv__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
+        return rmodule_op(self, other, operator.floordiv)
+
+    def __rpow__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
+        return rmodule_op(self, other, operator.pow)
+
+    def mean(
+        self, dim: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False
+    ) -> "CompositeLoss":
+        """
+        Composable torch.mean reduction operator. See torch.mean for more details:
+        https://pytorch.org/docs/stable/generated/torch.mean.html
 
-            name = self.__name__
-            target = self.target
-        elif isinstance(other, Loss):
-            # This should never get called because __div__ will be called instead
-            pass
+        Args:
+            dim (int or tuple of int, optional): The dimension or dimensions to reduce.
+                Default: None to reduce over all dimensions.
+            keepdim (bool, optional): Whether the output tensor has dim retained or
+                not.
+                Default: False
+
+        Returns:
+            composite_loss (CompositeLoss): A composable loss instance.
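+
+        Example (a minimal sketch; ``model.layer`` stands in for any supported
+        target module)::
+
+            >>> loss = LayerActivation(model.layer).mean()
+            >>> # or reduce only the spatial dims, keeping batch & channel dims:
+            >>> loss = LayerActivation(model.layer).mean(dim=(2, 3), keepdim=True)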
+        """
+        if dim is None:
+            return custom_composable_op(self, torch.mean)
         else:
-            raise TypeError(
-                "Can only apply math operations with int, float or Loss. Received type "
-                + str(type(other))
-            )
-        return CompositeLoss(loss_fn, name=name, target=target)
+            return custom_composable_op(self, torch.mean, dim=dim, keepdim=keepdim)
 
-    def __rpow__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
-        if isinstance(other, (int, float)):
+    def sum(
+        self, dim: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False
+    ) -> "CompositeLoss":
+        """
+        Composable torch.sum reduction operator. See torch.sum for more details:
+        https://pytorch.org/docs/stable/generated/torch.sum.html
+
+        Args:
 
-            def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
-                return operator.pow(other, torch.mean(self(module)))
+            dim (int or tuple of int, optional): The dimension or dimensions to reduce.
+                Default: None to reduce over all dimensions.
+            keepdim (bool, optional): Whether the output tensor has dim retained or
+                not.
+                Default: False
 
-            name = self.__name__
-            target = self.target
-        elif isinstance(other, Loss):
-            # This should never get called because __pow__ will be called instead
-            pass
+        Returns:
+            composite_loss (CompositeLoss): A composable loss instance.
+        """
+        if dim is None:
+            return custom_composable_op(self, torch.sum)
         else:
-            raise TypeError(
-                "Can only apply math operations with int, float or Loss. Received type "
-                + str(type(other))
-            )
-        return CompositeLoss(loss_fn, name=name, target=target)
+            return custom_composable_op(self, torch.sum, dim=dim, keepdim=keepdim)
 
 
 def module_op(
@@ -107,7 +139,7 @@ def module_op(
     """
     This is a general function for applying math operations to Losses
     """
-    if other is None and math_op == operator.neg:
+    if other is None and math_op in [operator.neg, operator.pos, operator.abs]:
 
         def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
             return math_op(self(module))
@@ -124,7 +156,7 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
     elif isinstance(other, Loss):
         # We take the mean of the output tensor to resolve shape mismatches
         def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
-            return math_op(torch.mean(self(module)), torch.mean(other(module)))
+            return math_op(REDUCTION_OP(self(module)), REDUCTION_OP(other(module)))
 
         name = f"Compose({', '.join([self.__name__, other.__name__])})"
         target = (
@@ -138,18 +170,110 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
     return CompositeLoss(loss_fn, name=name, target=target)
 
 
+def rmodule_op(
+    self: Loss, other: Union[int, float, "Loss"], math_op: Callable
+) -> "CompositeLoss":
+    """
+    This is a general function for applying the "r" (reflected) versions of math
+    operations to Losses.
+    """
+    if isinstance(other, (int, float)):
+
+        def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
+            return math_op(other, self(module))
+
+        name = self.__name__
+        target = self.target
+    elif isinstance(other, Loss):
+        # This should never get called because __math_op__ will be called instead
+        pass
+    else:
+        raise TypeError(
+            "Can only apply math operations with int, float or Loss. Received type "
+            + str(type(other))
+        )
+    return CompositeLoss(loss_fn, name=name, target=target)
+
+
+def custom_composable_op(
+    loss: Union[Loss, List[Loss]],
+    loss_op_fn: Callable,
+    *args: Any,
+    **kwargs: Any,
+) -> "CompositeLoss":
+    """
+    Implement composability for operations that take a single tensor or list of tensors
+    and then return a single tensor. Custom user-defined functions can be used in
+    addition to some built-in Python functions and PyTorch operations.
+
+    Args:
+
+        loss (Loss or list of Loss): A loss objective or list of loss objectives.
+        loss_op_fn (Callable): A supported PyTorch, Python, or custom function that
+            reduces one or more loss tensors to a single tensor.
+        args (Any, optional): Any additional arguments to pass to loss_op_fn.
+        kwargs (Any, optional): Any additional arguments to pass to loss_op_fn.
+        to_scalar_fn (Callable, optional): A function for converting loss function
+            outputs to scalar values, in order to prevent size mismatches. This
+            variable is only used if more than one loss is given.
+            Default: None
+
+    Returns:
+        composite_loss (CompositeLoss): A composable loss instance.
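+
+    Example (a minimal sketch; ``model.layer`` stands in for any supported
+    target module)::
+
+        >>> losses = [ChannelActivation(model.layer, 0),
+        ...           ChannelActivation(model.layer, 1)]
+        >>> # reduce each loss to a scalar, then sum the two scalars
+        >>> loss = custom_composable_op(losses, sum, to_scalar_fn=torch.mean)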
+    """
+
+    if isinstance(loss, (tuple, list)):
+        if "to_scalar_fn" not in kwargs:
+            to_scalar_fn = None
+        else:
+            to_scalar_fn = kwargs["to_scalar_fn"]
+            del kwargs["to_scalar_fn"]
+
+        def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
+            loss_tensors = [loss_obj(module) for loss_obj in loss]
+            if to_scalar_fn is not None:
+                loss_tensors = [to_scalar_fn(tensor) for tensor in loss_tensors]
+            return loss_op_fn(loss_tensors, *args, **kwargs)
+
+        name_list = ", ".join([loss_obj.__name__ for loss_obj in loss])
+        name = loss_op_fn.__name__ + "(" + name_list + ")"
+
+        # Collect targets from losses, flattening any nested target lists
+        target = [
+            target
+            for targets in [
+                [loss_obj.target]
+                if not hasattr(loss_obj.target, "__iter__")
+                else loss_obj.target
+                for loss_obj in loss
+            ]
+            for target in targets
+        ]
+    else:
+
+        def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
+            return loss_op_fn(loss(module), *args, **kwargs)
+
+        name = loss_op_fn.__name__ + "(" + loss.__name__ + ")"
+        target = loss.target
+    return CompositeLoss(loss_fn, name=name, target=target)
+
+
 class BaseLoss(Loss):
     def __init__(
         self,
         target: Union[nn.Module, List[nn.Module]] = [],
-        batch_index: Optional[int] = None,
+        batch_index: Optional[Union[int, List[int]]] = None,
     ) -> None:
         super(BaseLoss, self).__init__()
         self._target = target
         if batch_index is None:
             self._batch_index = (None, None)
+        elif isinstance(batch_index, (list, tuple)):
+            self._batch_index = tuple(batch_index)
         else:
             self._batch_index = (batch_index, batch_index + 1)
+        assert all([isinstance(b, (int, type(None))) for b in self._batch_index])
 
     @property
     def target(self) -> Union[nn.Module, List[nn.Module]]:
@@ -343,6 +467,7 @@ class Diversity(BaseLoss):
 
     def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         activations = targets_to_values[self.target]
+        activations = activations[self.batch_index[0] : self.batch_index[1]]
         batch, channels = activations.shape[:2]
         flattened = activations.view(batch, channels, -1)
         grams = torch.matmul(flattened, torch.transpose(flattened, 1, 2))
@@ -417,12 +542,18 @@ class Alignment(BaseLoss):
     https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
     """
 
-    def __init__(self, target: nn.Module, decay_ratio: float = 2.0) -> None:
-        BaseLoss.__init__(self, target)
+    def __init__(
+        self,
+        target: nn.Module,
+        decay_ratio: float = 2.0,
+        batch_index: Optional[List[int]] = None,
+    ) -> None:
+        BaseLoss.__init__(self, target, batch_index)
         self.decay_ratio = decay_ratio
 
     def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         activations = targets_to_values[self.target]
+        activations = activations[self.batch_index[0] : self.batch_index[1]]
         B = activations.size(0)
         sum_tensor = torch.zeros(1, device=activations.device)
@@ -737,6 +868,7 @@ def default_loss_summarize(loss_value: torch.Tensor) -> torch.Tensor:
 
 __all__ = [
     "Loss",
+    "REDUCTION_OP",
     "loss_wrapper",
     "BaseLoss",
     "LayerActivation",
diff --git a/tests/optim/core/test_loss.py b/tests/optim/core/test_loss.py
index de38f9bdda..ebdb962b34 100644
--- a/tests/optim/core/test_loss.py
+++ b/tests/optim/core/test_loss.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+import unittest
 from typing import List, Union
 
 import numpy as np
@@ -193,6 +194,25 @@ def test_negative(self) -> None:
             get_loss_value(model, loss), -CHANNEL_ACTIVATION_0_LOSS, places=6
         )
 
+    def test_positive(self) -> None:
+        # Compare parsed version components; a plain string comparison would
+        # wrongly treat e.g. "1.10.0" as older than "1.3.0"
+        if tuple(int(v) for v in torch.__version__.split(".")[:2]) <= (1, 3):
+            raise unittest.SkipTest(
+                "Skipping positive CompositeLoss test due to insufficient"
+                + " Torch version."
+            )
+        model = BasicModel_ConvNet_Optim()
+        loss = +opt_loss.ChannelActivation(model.layer, 0)
+        self.assertAlmostEqual(
+            get_loss_value(model, loss), CHANNEL_ACTIVATION_0_LOSS, places=6
+        )
+
+    def test_abs(self) -> None:
+        model = BasicModel_ConvNet_Optim()
+        loss = abs(-opt_loss.ChannelActivation(model.layer, 0))
+        self.assertAlmostEqual(
+            get_loss_value(model, loss), CHANNEL_ACTIVATION_0_LOSS, places=6
+        )
+
     def test_addition(self) -> None:
         model = BasicModel_ConvNet_Optim()
         loss = (
@@ -250,6 +270,13 @@ def test_division(self) -> None:
         #     model.layer, 1
         # )
 
+    def test_floor_division(self) -> None:
+        model = BasicModel_ConvNet_Optim()
+        loss = opt_loss.ChannelActivation(model.layer, 0) // 10
+        self.assertAlmostEqual(
+            get_loss_value(model, loss), CHANNEL_ACTIVATION_0_LOSS // 10
+        )
+
     def test_pow(self) -> None:
         model = BasicModel_ConvNet_Optim()
         loss = opt_loss.ChannelActivation(model.layer, 0) ** 2
@@ -268,6 +295,83 @@ def test_pow(self) -> None:
         #     model.layer, 1
         # )
 
+    def test_sum(self) -> None:
+        model = torch.nn.Identity()
+        loss = opt_loss.LayerActivation(model).sum()
+        self.assertAlmostEqual(get_loss_value(model, loss), 3.0, places=1)
+
+    def test_mean(self) -> None:
+        model = torch.nn.Identity()
+        loss = opt_loss.LayerActivation(model).mean()
+        self.assertAlmostEqual(get_loss_value(model, loss), 1.0, places=1)
+
+
+class TestCompositeLossReductionOP(BaseTest):
+    def test_reduction_op(self) -> None:
+        self.assertEqual(opt_loss.REDUCTION_OP, torch.mean)
+
+
+class TestCustomComposableOP(BaseTest):
+    def test_torch_sum(self) -> None:
+        model = torch.nn.Identity()
+        loss = opt_loss.LayerActivation(model)
+        loss = opt_loss.custom_composable_op(loss, loss_op_fn=torch.sum)
+        self.assertAlmostEqual(get_loss_value(model, loss), 3.0, places=1)
+
+    def test_sum_list_with_scalar_fn(self) -> None:
+        model = torch.nn.Identity()
+        loss_list = [
+            opt_loss.LayerActivation(model),
+            opt_loss.LayerActivation(model),
+            opt_loss.LayerActivation(model),
+            opt_loss.LayerActivation(model),
+            opt_loss.LayerActivation(model),
+        ]
+        loss = opt_loss.custom_composable_op(
+            loss_list, loss_op_fn=sum, to_scalar_fn=torch.mean
+        )
+        self.assertAlmostEqual(get_loss_value(model, loss), 5.0, places=1)
+
+    def test_custom_op(self) -> None:
+        def custom_op_fn(
+            losses: torch.Tensor, add_val: float = 1.0, mul_val: float = 1.0
+        ) -> torch.Tensor:
+            return torch.sum(losses) + add_val * mul_val
+
+        model = torch.nn.Identity()
+        loss = opt_loss.LayerActivation(model)
+
+        loss = opt_loss.custom_composable_op(
+            loss, loss_op_fn=custom_op_fn, add_val=2.0, mul_val=2.0
+        )
+        self.assertAlmostEqual(get_loss_value(model, loss), 7.0, places=1)
+
+    def test_custom_op_list(self) -> None:
+        def custom_op_list_fn(
+            losses: List[torch.Tensor], add_val: float = 1.0, mul_val: float = 1.0
+        ) -> torch.Tensor:
+            # torch.stack is used because torch.cat cannot concatenate the
+            # zero-dimensional tensors produced by torch.sum
+            return torch.stack(
+                [torch.sum(loss) + add_val * mul_val for loss in losses], 0
+            ).sum()
+
+        model = torch.nn.Identity()
+        loss_list = [
+            opt_loss.LayerActivation(model),
+            opt_loss.LayerActivation(model),
+            opt_loss.LayerActivation(model),
+            opt_loss.LayerActivation(model),
+            opt_loss.LayerActivation(model),
+        ]
+        loss = opt_loss.custom_composable_op(
+            loss_list,
+            loss_op_fn=custom_op_list_fn,
+            add_val=2.0,
+            mul_val=2.0,
+        )
+        self.assertAlmostEqual(get_loss_value(model, loss), 35.0, places=1)
+
+
+class TestSumLossList(BaseTest):
     def test_sum_loss_list(self) -> None:
         n_batch = 400
         model = torch.nn.Identity()