From da78fb74373cdbfb482e56564fbd886b930a938b Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Thu, 7 Apr 2022 11:24:11 -0600 Subject: [PATCH 1/3] Fix duplicated target bug --- captum/optim/_core/loss.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py index 6ddaf43300..de2ebe4da4 100644 --- a/captum/optim/_core/loss.py +++ b/captum/optim/_core/loss.py @@ -130,6 +130,9 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: target = (self.target if isinstance(self.target, list) else [self.target]) + ( other.target if isinstance(other.target, list) else [other.target] ) + + # Filter out duplicate targets + target = list(dict.fromkeys(target)) else: raise TypeError( "Can only apply math operations with int, float or Loss. Received type " From f86d7fb8e94f131e92389ed3e6e0cca962115e09 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Thu, 7 Apr 2022 11:38:04 -0600 Subject: [PATCH 2/3] Fix duplicated target bug in `sum_loss_list` & `collect_activations` --- captum/optim/_core/loss.py | 3 +++ captum/optim/models/_common.py | 1 + 2 files changed, 4 insertions(+) diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py index de2ebe4da4..463d26b71c 100644 --- a/captum/optim/_core/loss.py +++ b/captum/optim/_core/loss.py @@ -724,6 +724,9 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: ] for target in targets ] + + # Filter out duplicate targets + target = list(dict.fromkeys(target)) return CompositeLoss(loss_fn, name=name, target=target) diff --git a/captum/optim/models/_common.py b/captum/optim/models/_common.py index e4b15ab1dc..719635de66 100644 --- a/captum/optim/models/_common.py +++ b/captum/optim/models/_common.py @@ -193,6 +193,7 @@ def collect_activations( """ if not isinstance(targets, list): targets = [targets] + targets = list(dict.fromkeys(targets)) catch_activ = ActivationFetcher(model, targets) activ_out = catch_activ(model_input) return activ_out From c04a29a147a642e8c39f3367872cf0ea50f5339b Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sun, 22 May 2022 18:17:34 -0600 Subject: [PATCH 3/3] Add ToDo comment for target handling --- captum/optim/_core/loss.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py index 12c012ad97..3f8ab79a22 100644 --- a/captum/optim/_core/loss.py +++ b/captum/optim/_core/loss.py @@ -126,6 +126,8 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: return math_op(torch.mean(self(module)), torch.mean(other(module))) name = f"Compose({', '.join([self.__name__, other.__name__])})" + + # ToDo: Refine logic for self.target handling target = (self.target if isinstance(self.target, list) else [self.target]) + ( other.target if isinstance(other.target, list) else [other.target] )