Make target_layer also a nn.Module #81

Closed
FrancescoSaverioZuppichini opened this issue Aug 18, 2021 · 7 comments · Fixed by #83
Labels: module: methods, question, type: improvement


@FrancescoSaverioZuppichini

🚀 Feature

Hello there,

I would like to pass an nn.Module instance to any *Cam constructor but currently, I can only use 'str'.

Motivation

Well, if your model has a good design it is much easier to pass a reference than the key.

Thanks!

Francesco

@frgfm
Owner

frgfm commented Aug 23, 2021

Hi @FrancescoSaverioZuppichini 👋

Thanks for the suggestion! I thought about it a while ago, but couldn't really see cases where it's better on the user side than passing the string.

Currently

# Define your model
from torchvision.models import resnet18
model = resnet18(pretrained=True).eval()

# Set your CAM extractor
from torchcam.cams import SmoothGradCAMpp
cam_extractor = SmoothGradCAMpp(model, 'layer4')

and with your suggestion

# Define your model
from torchvision.models import resnet18
model = resnet18(pretrained=True).eval()

# Set your CAM extractor
from torchcam.cams import SmoothGradCAMpp
cam_extractor = SmoothGradCAMpp(model, model.layer4)

I agree that passing a reference, rather than a string that we'll later resolve to a reference, is cleaner from a programming point of view, but I can see at least one drawback: if you pass a target layer that belongs to another model by mistake, the CAM won't work. This also introduces some redundancy in the arguments: on top of the target layer, the whole model still has to be passed (to clear gradients, and for forward passes in some specific CAM methods).

Could you tell me in which cases the string argument is not enough on your end please? 🙏

@frgfm added the "module: methods" and "question" labels on Aug 23, 2021
@FrancescoSaverioZuppichini
Author

Hi @frgfm ,

Thank you for your reply. I hope you are doing great ;) IMHO, passing a reference is always better than passing a string.

First of all, if I pass a string you'll need to trace the model to find out where the layer is. Secondly, maybe I have no idea what my "key" is. What if I want to reference a layer by index in a list of layers? I cannot do that with a string, whereas by reference it is as easy as model.layers[idx]. Or what if I have a very nested model, something like model.encoder.blocks[1].head.projection? By reference it is dead easy to pass what you need, but by string? What about collisions? If I have two layers with the same name, as often happens in practice, the current implementation will not be able to tell which one is the correct one.

You can always support string keys by having something like .from_key(model, key: str)

It is not a big deal. I would love to use your *Cams in my projects, but I can just yoink the code and change it.

Thank you again!

Cheers,

Francesco

@frgfm
Owner

frgfm commented Sep 2, 2021

Being a Python perfectionist, I obviously agree with you that passing a ref is better than a string 👌

However here there are a few things to consider:

  • the string constraint makes it mandatory for a user to target a layer inside the model you observe. If you pass a reference, it's gonna be quite tricky to check for errors (if that person passes a model, and a reference to a layer of another model). Agreed, this will not happen every day, but at least, currently, the raised error will easily point the user towards the issue
  • there is a lack of documentation about how to specify the sublayer of a model, but your concerns about indexing should be lifted: since "fix: Support CAM for intermediate layers" (#21), subpart.1 points to the second item of the nn.Sequential called subpart in your model, for instance 😄
  • regarding collisions, I'm not sure I understand your point? I don't foresee any improvement on collisions from switching from string to reference. If a given layer is called several times, since TorchCAM uses PyTorch hooks, there is no way (I mean with what's going on under the hood) to pinpoint the i-th call. If you meant that two layers are named the same way, this is not possible in a torch.nn.Module, as the second assignment to a name overrides the first reference.
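The naming point in the last bullet can be checked with a small sketch. The `Toy` module and its attribute name `block` are made up for illustration; only standard `torch.nn` calls are used.

```python
import torch.nn as nn


class Toy(nn.Module):
    """Hypothetical module that assigns the same attribute name twice."""

    def __init__(self):
        super().__init__()
        self.block = nn.Linear(4, 4)
        # Re-assigning the same name replaces the first layer in the registry
        self.block = nn.ReLU()


toy = Toy()
# named_modules() yields (name, module) pairs; '' is the root module itself
names = [name for name, _ in toy.named_modules()]
print(names)
print(type(toy.block).__name__)
```

Only one submodule is registered under `block`, so a name cannot refer to two different layers of the same parent.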

I chose the string interface initially because it matches the way state_dict keys are created for a given Module. In both your examples, you would get the exact same results with a string argument by passing "layers.idx" and "encoder.blocks.1.head.projection" respectively. So I'm really curious about why the string argument is a limitation usage-wise. I'd really like to understand so that I can improve the library support 🙏
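As a rough illustration of the equivalence described above, the sketch below builds a made-up nested model (the `Head`, `Block`, `Encoder` and `Model` classes are assumptions that only mimic the attribute path from the discussion) and checks that the dotted name and the attribute reference point to the same object. Indices of an `nn.ModuleList` or `nn.Sequential` appear as plain digits in the dotted name.

```python
import torch.nn as nn


# Hypothetical nested model, only meant to mimic the paths discussed above
class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.projection = nn.Conv2d(8, 8, kernel_size=1)


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.head = Head()


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([Block(), Block()])


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()


model = Model()

# Dotted names follow the same convention as state_dict keys
name_to_module = dict(model.named_modules())
by_name = name_to_module["encoder.blocks.1.head.projection"]
by_ref = model.encoder.blocks[1].head.projection

# Both lookups return the very same module object
print(by_name is by_ref)
```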

Anyway, if this is something better for the community, I'm happy to add support for this!

@FrancescoSaverioZuppichini
Author

Hey @frgfm , thank you for your reply and the amazing discussion :) I love your enthusiasm!

  • I agree but python is not typed, a user can pass everything so yeah I don't know :)
  • Got it
  • This is my fault, I wasn't clear. So about collisions, you are correct!

But I was thinking: if you pass a string, then you need to trace the module to find it. Wouldn't it be more convenient to pass a reference?

Imagine the following situation: I have a model and I know which layer I want to target. Why should I waste computation on a forward pass? I can just pass my reference and boom, we hook the hook :) What do you think?
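For context, here is a minimal sketch of what hooking a layer by reference looks like in plain PyTorch. It is not TorchCAM's actual implementation; the small `nn.Sequential` model, the `captured` dictionary and the `save_output` function are all assumptions made for the example. Registering the hook needs no forward pass; a forward pass is only needed to actually capture an activation.

```python
import torch
import torch.nn as nn

# Hypothetical toy model; the last convolution plays the role of the target layer
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
)
target_layer = model[2]  # a direct reference, no name lookup involved

captured = {}


def save_output(module, inputs, output):
    # Forward hooks receive the module, its inputs and its output
    captured["activation"] = output.detach()


# Registering the hook does not run the model
handle = target_layer.register_forward_hook(save_output)

# The activation is only stored once a forward pass happens
with torch.no_grad():
    model(torch.rand(1, 3, 16, 16))
handle.remove()

print(captured["activation"].shape)
```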

Thank you again!

Francesco :)

@frgfm linked a pull request on Sep 12, 2021 that will close this issue
@frgfm
Owner

frgfm commented Sep 21, 2021

My apologies Francesco, I've been quite busy with other projects lately!

I'm not sure I follow the reasoning about wasted computation? 😅
Currently, there are two cases:

  1. A value for target_layer is provided: the layer is retrieved in O(1) complexity since it's just accessing the naming dictionary
  2. No value is provided: one forward is "wasted" to retrieve the best candidate layer

Now if reference passing is added:

  1. Value provided: the only difference is that the dictionary doesn't have to be built, but it's a dummy call of torch.nn.Module.named_modules so I would argue it really makes no difference
  2. No value is provided: same process as before

In my opinion, when a reference is provided, it would even be somewhat important to do a dummy inference to ensure the module indeed belongs to this model 😅 I guess a warning could be thrown when trying to compute the CAM if the reference was incorrect (whereas with the string argument, that would have been caught in the constructor itself).

My best suggestion would be then:

  • adding support of reference
  • in the constructor, we still build the naming dictionary, and go through this to ensure that the reference is one of the values (to avoid issues later on)
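The two points above could be sketched roughly as follows. This is only an illustration under assumptions, not the code that ended up in the library: the helper name `resolve_target_layer` and the error messages are hypothetical, and the check relies solely on `torch.nn.Module.named_modules`.

```python
from typing import Union

import torch.nn as nn


def resolve_target_layer(model: nn.Module, target_layer: Union[str, nn.Module]) -> str:
    """Hypothetical helper: return the registered name of target_layer in model."""
    # Naming dictionary, built the same way for both argument types
    name_to_module = dict(model.named_modules())

    if isinstance(target_layer, str):
        if target_layer not in name_to_module:
            raise ValueError(f"Unable to find submodule {target_layer!r} in the model")
        return target_layer

    if isinstance(target_layer, nn.Module):
        # Identity check: the reference must be one of the model's own submodules
        for name, module in name_to_module.items():
            if module is target_layer:
                return name
        raise ValueError("The provided module is not a submodule of the model")

    raise TypeError("target_layer must be a str or an nn.Module")


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
other = nn.Sequential(nn.Linear(4, 4))

name_from_ref = resolve_target_layer(model, model[2])
name_from_str = resolve_target_layer(model, "0")
print(name_from_ref, name_from_str)

try:
    # A layer from another model is rejected at construction time
    resolve_target_layer(model, other[0])
except ValueError as err:
    print(err)
```

With this kind of check, a layer taken from the wrong model fails immediately rather than when the CAM is computed.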

What do you think? :)

@FrancescoSaverioZuppichini
Author

Hey @frgfm, what's up! PyTorch recently released a new way to extract features. I was thinking you may find it interesting :)

@frgfm
Owner

frgfm commented Jan 6, 2022

Hi @FrancescoSaverioZuppichini 👋

Sorry about the late reply! Yup, I saw that in the previous release notes, and wanted to try it out. I'll see if that can help this project 👍
