diff --git a/captum/_utils/common.py b/captum/_utils/common.py
index 5b83b0be2..ee239d5f9 100644
--- a/captum/_utils/common.py
+++ b/captum/_utils/common.py
@@ -759,10 +759,10 @@ def _reduce_list(
     Applies reduction function to given list. If each element in the list is
     a Tensor, applies reduction function to all elements of the list, and returns
     the output Tensor / value. If each element is a boolean, apply any method (or).
-    If each element is a tuple, applies reduction
-    function to corresponding elements of each tuple in the list, and returns
+    If each element is a tuple/list, applies reduction
+    function to corresponding elements of each tuple/list in the list, and returns
     tuple of reduction function outputs with length matching the length of tuple
-    val_list[0]. It is assumed that all tuples in the list have the same length
+    val_list[0]. It is assumed that all tuples/lists in the list have the same length
     and red_func can be applied to all elements in each corresponding position.
     """
     assert len(val_list) > 0, "Cannot reduce empty list!"
@@ -774,7 +774,7 @@ def _reduce_list(
     elif isinstance(val_list[0], bool):
         # pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `bool`.
         return any(val_list)
-    elif isinstance(val_list[0], tuple):
+    elif isinstance(val_list[0], (tuple, list)):
         final_out = []
         # pyre-fixme[6]: For 1st argument expected `pyre_extensions.ReadOnly[Sized]`
         #  but got `TupleOrTensorOrBoolGeneric`.
@@ -786,7 +786,7 @@ def _reduce_list(
     else:
         raise AssertionError(
             "Elements to be reduced can only be"
-            "either Tensors or tuples containing Tensors."
+            "either Tensors or tuples/lists containing Tensors."
         )
     # pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `Tuple[Any, ...]`.
     return tuple(final_out)
diff --git a/captum/_utils/gradient.py b/captum/_utils/gradient.py
index fba9c1ba7..8cf59e6dd 100644
--- a/captum/_utils/gradient.py
+++ b/captum/_utils/gradient.py
@@ -46,8 +46,8 @@ def apply_gradient_requirements(
     a tensor originally required grad is returned.
     """
     assert isinstance(
-        inputs, tuple
-    ), "Inputs should be wrapped in a tuple prior to preparing for gradients"
+        inputs, (tuple, list)
+    ), "Inputs should be wrapped in a tuple or list prior to preparing for gradients"
     grad_required = []
     for index, input in enumerate(inputs):
         assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
@@ -298,9 +298,9 @@ def hook_wrapper(original_module):
         # pyre-fixme[2]: Parameter must be annotated.
         def forward_hook(module, inp, out=None):
             eval_tsrs = inp if attribute_to_layer_input else out
-            is_eval_tuple = isinstance(eval_tsrs, tuple)
+            is_eval_tuple_or_list = isinstance(eval_tsrs, (tuple, list))
 
-            if not is_eval_tuple:
+            if not is_eval_tuple_or_list:
                 eval_tsrs = (eval_tsrs,)
             if require_layer_grads:
                 apply_gradient_requirements(eval_tsrs, warn=False)
@@ -310,11 +310,16 @@ def forward_hook(module, inp, out=None):
                 # otherwise `backward()` on the last output layer won't execute.
                 if forward_hook_with_return:
                     saved_layer[original_module][eval_tsrs[0].device] = eval_tsrs
-                    eval_tsrs_to_return = tuple(
-                        eval_tsr.clone() for eval_tsr in eval_tsrs
-                    )
-                    if not is_eval_tuple:
-                        eval_tsrs_to_return = eval_tsrs_to_return[0]
+                    if not is_eval_tuple_or_list:
+                        eval_tsrs_to_return = eval_tsrs[0].clone()
+                    elif isinstance(eval_tsrs, list):
+                        eval_tsrs_to_return = [
+                            eval_tsr.clone() for eval_tsr in eval_tsrs
+                        ]
+                    else:
+                        eval_tsrs_to_return = tuple(
+                            eval_tsr.clone() for eval_tsr in eval_tsrs
+                        )
                     return eval_tsrs_to_return
                 else:
                     saved_layer[original_module][eval_tsrs[0].device] = tuple(
diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py
index efa2c5456..a36168b03 100644
--- a/captum/testing/helpers/basic_models.py
+++ b/captum/testing/helpers/basic_models.py
@@ -432,6 +432,8 @@ def __init__(
         self.linear3.weight = nn.Parameter(torch.ones(2, 4))
         self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
 
+        self.list_output_layer = PassThroughLayerOutput()
+
         self.int_layer = PassThroughLayerOutput()  # sample layer with an int output
 
     @no_type_check
@@ -452,11 +454,13 @@ def forward(
         relu_out = self.relu(lin1_out)
         lin2_out = self.linear2(relu_out)
+        list_out = self.list_output_layer([nn.Linear(2, 2)(lin2_out) for _ in range(2)])
+        resized_list_out = torch.cat(list_out, dim=1)
 
         lin3_out = self.linear3(lin1_out_alt)
         int_output = self.int_layer(lin3_out.to(torch.int64))
 
-        output_tensors = torch.cat((lin2_out, int_output), dim=1)
+        output_tensors = torch.cat((resized_list_out, int_output), dim=1)
 
         return (
             output_tensors
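
Example (not part of the patch): a minimal sketch of what the list support above enables, assuming the private helpers `_reduce_list` and `apply_gradient_requirements` keep the signatures shown in the diff; the tensors are made-up sample data.

    import torch

    from captum._utils.common import _reduce_list
    from captum._utils.gradient import apply_gradient_requirements

    # Two per-batch layer outputs, each a list of tensors rather than a tuple.
    batch_a = [torch.ones(1, 3), torch.zeros(1, 2)]
    batch_b = [torch.full((1, 3), 2.0), torch.ones(1, 2)]

    # Reduces corresponding positions across the lists; before this patch the
    # final `else` branch raised an AssertionError for list elements.
    reduced = _reduce_list([batch_a, batch_b], red_func=lambda ts: torch.cat(ts, dim=0))
    # reduced is a tuple of two tensors with shapes (2, 3) and (2, 2).

    # Gradient preparation now accepts a list of tensors as well as a tuple.
    grad_required = apply_gradient_requirements(batch_a, warn=False)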
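
Illustrative sketch (assumed behaviour, not part of the patch): the updated forward_hook clones layer outputs while preserving their container type, so a layer that emits a list still returns a list after the hook runs. The `PassThrough` module below is a hypothetical stand-in for the `PassThroughLayerOutput` test helper.

    import torch
    import torch.nn as nn

    class PassThrough(nn.Module):
        # Returns its input unchanged, like the PassThroughLayerOutput helper.
        def forward(self, x):
            return x

    def cloning_hook(module, inp, out):
        # Same idea as the patched forward_hook: clone every tensor while
        # keeping the output's container type (list vs tuple vs bare tensor).
        if isinstance(out, list):
            return [t.clone() for t in out]
        if isinstance(out, tuple):
            return tuple(t.clone() for t in out)
        return out.clone()

    layer = PassThrough()
    handle = layer.register_forward_hook(cloning_hook)
    result = layer([torch.ones(2, 2), torch.zeros(2, 2)])
    assert isinstance(result, list)  # the list type survives the hook
    handle.remove()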