In [27]:
import torch
from torch import nn
from torchvision import models

In [41]:
class HookWrapper(nn.Module):
    '''
        Wrapper class for forward feature map extraction.

        Forward hooks are registered on the direct children of the model
        whose names appear in target_layers. Each time a hooked child runs,
        its output is appended to self.features as-is (not detached, not
        copied).

        - Usage -
        1) Make the instance of this class with the model and target layers.
        2) Forward it.
        3) Call get_features() to get the outputs recorded since the previous
           call; the internal list is reset at the same time.
        4) Back to 2).
        5) Call remove_hooks() to detach the hooks from the model when the
           wrapper is no longer needed.
    '''
    def __init__(self, model, target_layers):
        '''
            model         : module whose direct children are inspected.
            target_layers : container of child names to hook.
        '''
        super().__init__()
        self.model = model
        self.target_layers = target_layers
        self.features = []  # One entry per hooked-module call since the last get_features().
        # Handles returned by register_forward_hook; kept so that the hooks
        # can be removed later instead of staying on the model forever.
        self._hook_handles = []

        matched = set()
        for name, module in model.named_children():
            if name in target_layers:
                handle = module.register_forward_hook(self._extraction_fn)
                self._hook_handles.append(handle)
                matched.add(name)

        # A name that matches no direct child would otherwise be ignored
        # silently and get_features() would return fewer entries than asked.
        missing = [name for name in target_layers if name not in matched]
        if missing:
            import warnings
            warnings.warn(
                f"HookWrapper: no direct child module named {missing}; "
                "these target layers will not be recorded."
            )

    def _extraction_fn(self, module, input, output):
        # Forward-hook callback: record the output of the hooked module.
        self.features.append(output)

    def forward(self, x):
        # Delegates to the wrapped model; the hooks fill self.features.
        return self.model(x)

    def get_features(self): # Return feature list and make it empty.
        tmp = self.features
        self.features = []
        return tmp

    def remove_hooks(self):
        '''
            Detach every hook registered by this wrapper from the model.
            Calling it more than once is harmless.
        '''
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

# Names of the child modules of the network whose outputs are recorded.
target_layers = [f'layer{idx}' for idx in range(1, 5)]

# Build a pretrained ResNet-18 and wrap it so the listed layers are hooked.
net = models.resnet18(pretrained=True)
hook_net = HookWrapper(model=net, target_layers=target_layers)

In [45]:
# Random input tensor of shape (1, 3, 224, 224).
x = torch.randn(1, 3, 224, 224)

# One forward pass fills the hook buffer; get_features() returns and clears it.
pred = hook_net(x)
features = hook_net.get_features()
num_features = len(features)
print(num_features)

4
