Optim-wip: Add a ton of missing docs #571
Merged
Changes from all commits (12 commits):
f9ca6c1 Add a ton of missing docs (ProGamerGov)
ab19efb Add docs for NumPy helpers (ProGamerGov)
395ab63 Add image docs (ProGamerGov)
5a5b285 Add more docs (ProGamerGov)
4e65857 Merge remote-tracking branch 'upstream/optim-wip' into optim-wip (ProGamerGov)
66985f8 Improve docs (ProGamerGov)
f5a755d Add missing optional to doc (ProGamerGov)
e020ef7 Update docs to reflect new changes (ProGamerGov)
f67d534 Merge remote-tracking branch 'upstream/optim-wip' into optim-wip (ProGamerGov)
7c6ef63 Changes based on feedback (ProGamerGov)
1a4c0bf Fix Flake8 (ProGamerGov)
bdc3f17 Add missing 'optional's to docs (ProGamerGov)
Diff after all commits (unchanged context is elided with `...` comments; the file's existing imports of torch, nn, warnings, and the typing helpers sit above this hunk):

```python
from captum.optim._utils.typing import ModuleOutputMapping, TupleOfTensorsOrTensorType


class ModuleReuseException(Exception):
    pass


class ModuleOutputsHook:
    def __init__(self, target_modules: Iterable[nn.Module]) -> None:
        """
        Args:

            target_modules (Iterable of nn.Module): A list of nn.Module targets.
        """
        self.outputs: ModuleOutputMapping = dict.fromkeys(target_modules, None)
        self.hooks = [
            module.register_forward_hook(self._forward_hook())
            for module in target_modules
        ]

    def _reset_outputs(self) -> None:
        """
        Delete captured activations.
        """
        self.outputs = dict.fromkeys(self.outputs.keys(), None)

    @property
    def is_ready(self) -> bool:
        return all(value is not None for value in self.outputs.values())

    def _forward_hook(self) -> Callable:
        """
        Return the forward_hook function.

        Returns:
            forward_hook (Callable): The forward_hook function.
        """

        def forward_hook(
            module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor
        ) -> None:
            ...  # body unchanged in this diff

        return forward_hook

    def consume_outputs(self) -> ModuleOutputMapping:
        """
        Collect target activations and return them.

        Returns:
            outputs (ModuleOutputMapping): The captured outputs.
        """
        if not self.is_ready:
            warn(
                "Consume captured outputs, but not all requested target outputs "
                # ... rest of the message and method unchanged in this diff
            )

    def targets(self) -> Iterable[nn.Module]:
        ...  # context elided in this diff
        return self.outputs.keys()

    def remove_hooks(self) -> None:
        """
        Remove hooks.
        """
        for hook in self.hooks:
            hook.remove()

    def __del__(self) -> None:
        # print(f"DEL HOOKS!: {list(self.outputs.keys())}")
        """
        Ensure that using 'del' properly deletes hooks.
        """
        self.remove_hooks()
```
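For context, here is a minimal sketch of the hook lifecycle these methods implement. The model and layer attribute are hypothetical stand-ins, not part of this PR:

```python
import torch
import torchvision.models as models

# Hypothetical usage: capture the output of one layer with ModuleOutputsHook.
model = models.googlenet(pretrained=True).eval()
hook = ModuleOutputsHook([model.inception4a])

_ = model(torch.randn(1, 3, 224, 224))  # forward pass triggers the hooks
activations = hook.consume_outputs()    # dict keyed by the target modules
hook.remove_hooks()                     # detach the forward hooks when done
```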
```python
class ActivationFetcher:
    ...  # class docstring unchanged in this diff

    def __init__(self, model: nn.Module, targets: Iterable[nn.Module]) -> None:
        """
        Args:

            model (nn.Module): A reference to the PyTorch model instance.
            targets (nn.Module or list of nn.Module): The target layers to
                collect activations from.
        """
        super(ActivationFetcher, self).__init__()
        self.model = model
        self.layers = ModuleOutputsHook(targets)

    def __call__(self, input_t: TupleOfTensorsOrTensorType) -> ModuleOutputMapping:
        """
        Args:

            input_t (tensor or tuple of tensors, optional): The input to use
                with the specified model.

        Returns:
            activations_dict: A dict containing the collected activations. The keys
                for the returned dictionary are the target layers.
        """
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                self.model(input_t)
                activations_dict = self.layers.consume_outputs()
        finally:
            self.layers.remove_hooks()
        return activations_dict
```

Inline review on ActivationFetcher.__init__:

Reviewer: nit: documentation for init arguments.

ProGamerGov: I'll move all the init arguments from class descriptions to the class' init functions.
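A sketch of the one-shot ActivationFetcher wrapper, under the same assumptions (hypothetical model and layer choices):

```python
import torch
import torchvision.models as models

# Hypothetical usage: fetch activations from two layers in a single call.
# ActivationFetcher registers hooks, runs the forward pass, and removes the
# hooks in its finally block, so it is safe for one-off inspection.
model = models.googlenet(pretrained=True).eval()
fetcher = ActivationFetcher(model, targets=[model.inception4a, model.inception4b])

activations_dict = fetcher(torch.randn(1, 3, 224, 224))
for layer, output in activations_dict.items():
    print(type(layer).__name__, tuple(output.shape))
```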
Review conversation:

Reviewer: For the `optimize` method, `loss_summarize_fn` and `lr` aren't documented. Do you mind adding documentation there too?

ProGamerGov: Yeah, I'll add the documentation for those variables!
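For reference, a hypothetical sketch of the requested entries, in the same Args style used in the diff above; the types and wording are assumptions, not text from the PR:

```python
# Hypothetical docstring fragment for the optimize method (not from the PR):
"""
Args:

    loss_summarize_fn (Callable, optional): A function used to summarize the
        loss tensor into a single scalar value before the backward pass.
    lr (float, optional): The learning rate to use with the optimizer.
"""
```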
NarineK: @ProGamerGov, will we add the documentation for losses in a separate PR? https://github.com/pytorch/captum/blob/optim-wip/captum/optim/_core/loss.py
ProGamerGov: @NarineK Yes, we'll add the loss documentation in a separate PR!