Empty file modified: captum/optim/__init__.py (file mode 100755 → 100644)
20 changes: 16 additions & 4 deletions captum/optim/_core/optimization.py
@@ -46,6 +46,7 @@ def __init__(
) -> None:
r"""
Args:

model (nn.Module): A reference to the PyTorch model instance.
input_param (nn.Module, optional): A module that generates an input,
consumed by the model.
@@ -71,6 +72,7 @@ def __init__(

def loss(self) -> torch.Tensor:
r"""Compute loss value for current iteration.

Returns:
*tensor* representing **loss**:
- **loss** (*tensor*):
@@ -115,18 +117,26 @@ def optimize(
lr: float = 0.025,
) -> torch.Tensor:
r"""Optimize input based on loss function and objectives.

Args:

stop_criteria (StopCriteria, optional): A function that is called
every iteration and returns a bool that determines whether
to stop the optimization.
See captum.optim.typing.StopCriteria for details.
optimizer (Optimizer, optional): A torch.optim.Optimizer used to
optimize the input based on the loss function.

Contributor: For the optimize method, loss_summarize_fn and lr aren't documented. Do you mind adding documentation there too?

Contributor Author: Yeah, I'll add the documentation for those variables!

Contributor: @ProGamerGov, the documentation for losses we will add in a separate PR? https://github.com/pytorch/captum/blob/optim-wip/captum/optim/_core/loss.py

Contributor Author: @NarineK Yes, we'll add the loss documentation in a separate PR!

loss_summarize_fn (Callable, optional): The function to use for summarizing
tensor outputs from loss functions.
Default: default_loss_summarize
lr (float, optional): If no optimizer is given, then lr is used as the
learning rate for the Adam optimizer.
Default: 0.025

Returns:
*list* of *np.arrays* representing the **history**:
- **history** (*list*):
A list of loss values per iteration.
Length of the list corresponds to the number of iterations
history (torch.Tensor): A stack of loss values per iteration. The size
of the dimension on which loss values are stacked corresponds to
the number of iterations.
"""
stop_criteria = stop_criteria or n_steps(512)
optimizer = optimizer or optim.Adam(self.parameters(), lr=lr)
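The loss_summarize_fn contract is not spelled out in this hunk, so below is a minimal sketch of a replacement summarize function. It assumes, based only on the docstring above, that the function receives the raw loss tensor and must reduce it to a scalar tensor that optimize() can call backward() on; the name mean_loss_summarize is hypothetical.

import torch

def mean_loss_summarize(loss: torch.Tensor) -> torch.Tensor:
    # Reduce the per-element loss values to a single scalar for backward().
    return loss.mean()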
@@ -150,10 +160,12 @@ def optimize(

def n_steps(n: int, show_progress: bool = True) -> StopCriteria:
"""StopCriteria generator that uses number of steps as a stop criteria.

Args:
n (int): Number of steps to run optimization.
show_progress (bool, optional): Whether or not to show progress bar.
Default: True

Returns:
*StopCriteria* callable
"""
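A usage sketch tying optimize() and n_steps() together. It assumes an already-constructed InputOptimization instance named input_optim (its constructor is outside this diff) and simply passes the documented defaults explicitly.

import torch.optim as optim

# input_optim is assumed to be an existing InputOptimization instance.
stop = n_steps(512, show_progress=True)                # same default optimize() falls back to
adam = optim.Adam(input_optim.parameters(), lr=0.025)  # matches the documented lr default
history = input_optim.optimize(stop_criteria=stop, optimizer=adam)
print(history.shape[0])  # number of iterations recorded in the loss history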
54 changes: 47 additions & 7 deletions captum/optim/_core/output_hook.py (file mode 100755 → 100644)
@@ -8,26 +8,37 @@
from captum.optim._utils.typing import ModuleOutputMapping, TupleOfTensorsOrTensorType


class ModuleReuseException(Exception):
pass


class ModuleOutputsHook:
def __init__(self, target_modules: Iterable[nn.Module]) -> None:
"""
Args:

target_modules (Iterable of nn.Module): A list of nn.Module targets.
"""
self.outputs: ModuleOutputMapping = dict.fromkeys(target_modules, None)
self.hooks = [
module.register_forward_hook(self._forward_hook())
for module in target_modules
]

def _reset_outputs(self) -> None:
"""
Delete captured activations.
"""
self.outputs = dict.fromkeys(self.outputs.keys(), None)

@property
def is_ready(self) -> bool:
return all(value is not None for value in self.outputs.values())

def _forward_hook(self) -> Callable:
"""
Return the forward_hook function.

Returns:
forward_hook (Callable): The forward_hook function.
"""

def forward_hook(
module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor
) -> None:
Expand All @@ -49,6 +60,12 @@ def forward_hook(
return forward_hook

def consume_outputs(self) -> ModuleOutputMapping:
"""
Collect target activations and return them.

Returns:
outputs (ModuleOutputMapping): The captured outputs.
"""
if not self.is_ready:
warn(
"Consume captured outputs, but not all requested target outputs "
@@ -63,11 +80,16 @@ def targets(self) -> Iterable[nn.Module]:
return self.outputs.keys()

def remove_hooks(self) -> None:
"""
Remove hooks.
"""
for hook in self.hooks:
hook.remove()

def __del__(self) -> None:
# print(f"DEL HOOKS!: {list(self.outputs.keys())}")
"""
Ensure that using 'del' properly deletes hooks.
"""
self.remove_hooks()
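To make the hook lifecycle concrete, here is a small self-contained sketch of using ModuleOutputsHook directly; the import path is assumed from the file being edited, and the toy model is for illustration only.

import torch
import torch.nn as nn
from captum.optim._core.output_hook import ModuleOutputsHook  # assumed import path

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
hook = ModuleOutputsHook([model[0], model[2]])  # registers forward hooks on the targets

model(torch.randn(3, 4))            # forward pass captures the target outputs
captured = hook.consume_outputs()   # dict keyed by the target modules
print(captured[model[0]].shape)     # torch.Size([3, 8])
hook.remove_hooks()                 # detach the hooks when done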


@@ -77,16 +99,34 @@ class ActivationFetcher:
"""

def __init__(self, model: nn.Module, targets: Iterable[nn.Module]) -> None:
Contributor: nit: documentation for init arguments.

Contributor Author: I'll move all the init arguments from class descriptions to the class' init functions.

"""
Args:

model (nn.Module): A reference to the PyTorch model instance.
targets (nn.Module or list of nn.Module): The target layers to
collect activations from.
"""
super(ActivationFetcher, self).__init__()
self.model = model
self.layers = ModuleOutputsHook(targets)

def __call__(self, input_t: TupleOfTensorsOrTensorType) -> ModuleOutputMapping:
"""
Args:

input_t (tensor or tuple of tensors, optional): The input to use
with the specified model.

Returns:
activations_dict: A dict containing the collected activations. The keys
of the returned dictionary are the target layers.
"""

try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.model(input_t)
activations = self.layers.consume_outputs()
activations_dict = self.layers.consume_outputs()
finally:
self.layers.remove_hooks()
return activations
return activations_dict
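For completeness, a minimal end-to-end sketch of ActivationFetcher as documented above; the toy model and input size are illustrative only, and the import path is assumed from the file being edited.

import torch
import torch.nn as nn
from captum.optim._core.output_hook import ActivationFetcher  # assumed import path

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
fetch = ActivationFetcher(model, targets=[model[1]])

activations_dict = fetch(torch.randn(1, 3, 32, 32))
print(activations_dict[model[1]].shape)  # activations captured from the ReLU layer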