@amaarora
Contributor

@amaarora amaarora commented Mar 26, 2021

See this gist for example usage: https://gist.github.com/amaarora/c562da34b95d97f8254960bdca6a12d1

As this is my first PR, please let me know if I haven't used the correct formatting etc. I'm also happy to contribute towards creating a contributing.md that sets the standard rules for PRs.

Thanks Ross!

@rwightman
Collaborator

@amaarora thanks for the addition, I don't have any issues re style/guidelines.

On the impl, this will match the originals for the resnet/nf_resnet models but is not quite the same as nfnet. For nfnet they also tap the end of the residual path (before the addition/skipinit). That requires hooking specific sub-layers by name, not just the block module.

I wonder if it makes sense to use layer_names (the actual full layer name in the hierarchy), and possibly layer_types as well if that is desired. Below I show an example using layer names w/ fnmatch to grab the stem, the block output, and the residual path:

>>> import fnmatch
>>> import timm
>>> model = timm.create_model('nfnet_f0s')
>>> p = ('stem', 'stages.?.?', 'stages.?.?.attn_last')
>>> for n, m in model.named_modules():
...     if any(fnmatch.fnmatch(n, pt) for pt in p):
...         print(n)
...
stem
stages.0.0
stages.0.0.attn_last
stages.1.0
stages.1.0.attn_last
stages.1.1
stages.1.1.attn_last
stages.2.0
stages.2.0.attn_last
stages.2.1
stages.2.1.attn_last
stages.2.2
stages.2.2.attn_last
stages.2.3
stages.2.3.attn_last
stages.2.4
stages.2.4.attn_last
stages.2.5
stages.2.5.attn_last
stages.3.0
stages.3.0.attn_last
stages.3.1
stages.3.1.attn_last
stages.3.2
stages.3.2.attn_last

@amaarora
Contributor Author

Thanks @rwightman, I've implemented a version of your approach, but won't the matching pattern for the residual path vary across models?

For example, for nfnet_f0 it is 'stages.?.?.attn_last',
whereas for resnet50 it is 'layer?.?.bn2'.

Notice the difference in where the . sits after the initial stages or layer prefix in the two patterns. The pattern will vary further for RegNet and other models. So I wonder if it would be easier to have the user provide the matching pattern instead of layer_name, layer_type? I can't think of an easy abstraction to solve this and am wondering if you have any further ideas? :)
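To make the mismatch concrete, here is a small sketch using only fnmatch. The module-name lists are hand-written for illustration (an assumption, not output from timm); they just follow the naming shown in this thread.

```python
import fnmatch

# Hand-written, illustrative module names (assumed, not generated by timm).
nfnet_names = ['stem', 'stages.0.0', 'stages.0.0.attn_last',
               'stages.1.0', 'stages.1.0.attn_last']
resnet_names = ['conv1', 'layer1.0', 'layer1.0.bn2',
                'layer1.1', 'layer1.1.bn2']


def match(names, pattern):
    """Return the names that match a shell-style pattern."""
    return [n for n in names if fnmatch.fnmatch(n, pattern)]


# Each model family needs its own pattern for the residual path.
nfnet_hits = match(nfnet_names, 'stages.?.?.attn_last')
resnet_hits = match(resnet_names, 'layer?.?.bn2')

# The nfnet pattern finds nothing in the resnet-style names.
cross_hits = match(resnet_names, 'stages.?.?.attn_last')

print(nfnet_hits)
print(resnet_hits)
print(cross_hits)
```

Since no single pattern covers both naming schemes, the pattern probably has to come from the caller.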

@amaarora
Contributor Author

amaarora commented Mar 28, 2021

I suggest we update the register_hook to:

    def register_hook(self, hook_fn_loc, hook_fn):
        for name, module in self.model.named_modules():
            if not fnmatch.fnmatch(name, hook_fn_loc):
                continue
            module.register_forward_hook(self._create_hook(hook_fn))

Here, users provide the matching pattern as hook_fn_loc.
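For context, a minimal self-contained sketch of how such a register_hook could be used end to end. TinyNet and StatsHook are hypothetical stand-ins (not the PR's ActivationStatsHook), and the (module, inputs, output) signature of the stat function is an assumption borrowed from PyTorch's forward-hook signature.

```python
import fnmatch
from collections import defaultdict

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model standing in for a timm model."""

    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, 3, padding=1)
        self.stages = nn.Sequential(
            nn.Sequential(nn.Conv2d(8, 8, 3, padding=1)),
            nn.Sequential(nn.Conv2d(8, 8, 3, padding=1),
                          nn.Conv2d(8, 8, 3, padding=1)),
        )

    def forward(self, x):
        return self.stages(self.stem(x))


class StatsHook:
    """Hypothetical stand-in: collects one stat per matched module."""

    def __init__(self, model):
        self.model = model
        self.stats = defaultdict(list)

    def _create_hook(self, hook_fn):
        # Wrap hook_fn so its result is stored under the function's name.
        def hook(module, inputs, output):
            self.stats[hook_fn.__name__].append(hook_fn(module, inputs, output))
        return hook

    def register_hook(self, hook_fn_loc, hook_fn):
        # Attach the hook to every module whose full name matches the pattern.
        for name, module in self.model.named_modules():
            if not fnmatch.fnmatch(name, hook_fn_loc):
                continue
            module.register_forward_hook(self._create_hook(hook_fn))


def avg_ch_var(module, inputs, output):
    # Variance per channel over batch and spatial dims, averaged over channels.
    return torch.mean(output.var(dim=(0, 2, 3))).item()


model = TinyNet().eval()
hook = StatsHook(model)
hook.register_hook('stages.?.?', avg_ch_var)  # matches stages.0.0, 1.0, 1.1
with torch.no_grad():
    model(torch.randn(4, 3, 16, 16))
n_stats = len(hook.stats['avg_ch_var'])
print(n_stats)
```

One forward pass records one value per matched module, in the order the modules run.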

@amaarora
Contributor Author

@rwightman I've updated PR. Feel that letting the user provide pattern is the best and most flexible way to keep the ActivationStatsHook generic to work with all timm models.

Here is an example that plots SPPs including the last layer on residual branch: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950

avg_ch_var and avg_ch_var_residual are the same function but repeated because stats inside ActivationStatsHook get stored in the dictionary based on function __name__.
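A toy illustration of why the duplicate is needed, assuming stats are keyed by the function's __name__ as described above. The record helper and the list-based stat function are hypothetical stand-ins, not the hook's real code.

```python
import statistics
from collections import defaultdict

stats = defaultdict(list)


def record(hook_fn, values):
    # Hypothetical helper: store the result under the function's __name__.
    stats[hook_fn.__name__].append(hook_fn(values))


def avg_ch_var(values):
    # Stand-in for the real stat; here just a population variance of a list.
    return statistics.pvariance(values)


def avg_ch_var_residual(values):
    # Same computation under a second name, so it gets its own key.
    return avg_ch_var(values)


record(avg_ch_var, [1.0, 3.0])           # block outputs
record(avg_ch_var_residual, [0.0, 4.0])  # residual-path outputs

# A plain alias would not help: it keeps the original __name__.
alias = avg_ch_var
alias_name = alias.__name__

print(dict(stats))
print(alias_name)
```

With a single function registered at both locations, the two sets of values would land in the same list and could not be plotted separately.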

@rwightman
Collaborator

rwightman commented Mar 29, 2021

> @rwightman I've updated PR. Feel that letting the user provide pattern is the best and most flexible way to keep the ActivationStatsHook generic to work with all timm models.

@amaarora Agreed, this way allows it to be used with any model. Since it does require a bit of knowledge about the model & how hooks work, it might be good to expand the docstring for your ActivationStatsHook to include a link to your gist as a concrete and immediately accessible example.

@amaarora
Contributor Author

> it might be good to expand the docstring for your ActivationStatsHook to include a link to your gist as a concrete and immediately accessible example

Done :)

@rwightman rwightman merged commit 2319cbb into huggingface:master Mar 29, 2021
guoriyue pushed a commit to guoriyue/pytorch-image-models that referenced this pull request May 24, 2024
Add `ActivationStatsHook`  to allow extracting activation stats for Signal Propogation Plots