In [1]:
import torch
from torchvision.models import resnet50, ResNet50_Weights

from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork

import time

# To assist you in designing the feature extractor you may want to print out
# the available nodes for resnet50.
m = resnet50(weights=ResNet50_Weights.DEFAULT)
train_nodes, eval_nodes = get_graph_node_names(resnet50())

# print(eval_nodes)

# # The lists returned, are the names of all the graph nodes (in order of
# # execution) for the input model traced in train mode and in eval mode
# # respectively. You'll find that `train_nodes` and `eval_nodes` are the same
# # for this example. But if the model contains control flow that's dependent
# # on the training mode, they may be different.

# # To specify the nodes you want to extract, you could select the final node
# # that appears in each of the main layers:
return_nodes = {
    # node_name: user-specified key for output dict
    'fc': 'fc'
    # 'layer2.3.relu_2': 'layer2',
    # 'layer3.5.relu_2': 'layer3',
    # 'layer4.2.relu_2': 'layer4',
}

# # But `create_feature_extractor` can also accept truncated node specifications
# # like "layer1", as it will just pick the last node that's a descendent of
# # of the specification. (Tip: be careful with this, especially when a layer
# # has multiple outputs. It's not always guaranteed that the last operation
# # performed is the one that corresponds to the output you desire. You should
# # consult the source code for the input model to confirm.)
# return_nodes = {
#     'layer1': 'layer1',
#     'layer2': 'layer2',
#     'layer3': 'layer3',
#     'layer4': 'layer4',
# }

# Now you can build the feature extractor. This returns a module whose forward
# method returns a dictionary like:
# {
#     'layer1': output of layer 1,
#     'layer2': output of layer 2,
#     'layer3': output of layer 3,
#     'layer4': output of layer 4,
# }

# # Let's put all that together to wrap resnet50 with MaskRCNN

# # MaskRCNN requires a backbone with an attached FPN
# class Resnet50WithFPN(torch.nn.Module):
#     def __init__(self):
#         super(Resnet50WithFPN, self).__init__()
#         # Get a resnet50 backbone
#         m = resnet50()
#         # Extract 4 main layers (note: MaskRCNN needs this particular name
#         # mapping for return nodes)
#         self.body = create_feature_extractor(
#             m, return_nodes={f'layer{k}': str(v)
#                              for v, k in enumerate([1, 2, 3, 4])})
#         # Dry run to get number of channels for FPN
#         inp = torch.randn(2, 3, 224, 224)
#         with torch.no_grad():
#             out = self.body(inp)
#         in_channels_list = [o.shape[1] for o in out.values()]
#         # Build FPN
#         self.out_channels = 256
#         self.fpn = FeaturePyramidNetwork(
#             in_channels_list, out_channels=self.out_channels,
#             extra_blocks=LastLevelMaxPool())

#     def forward(self, x):
#         x = self.body(x)
#         x = self.fpn(x)
#         return x


# # Now we can build our model!
# model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()

In [5]:
model = create_feature_extractor(m, return_nodes=return_nodes)


img = torch.Tensor(3, 224, 224).normal_().unsqueeze(0) # random image

start_t = time.time_ns()

feature = model(img) # run through the model

end_t = time.time_ns()

print(feature['fc'], (end_t-start_t)/1e6)

tensor([[ 1.2311e-01,  2.6627e-02, -2.0454e-01, -7.1961e-02, -1.6311e-01,
         -9.7348e-02, -1.3111e-01,  1.5864e-02, -1.2850e-01, -4.6101e-02,
         -1.2628e-01,  5.5929e-03, -7.5287e-02, -4.2988e-02, -1.5961e-01,
         -2.8518e-01, -1.7445e-01, -1.3087e-01,  2.8835e-02, -1.3057e-01,
         -4.9089e-02, -5.6938e-02, -7.7273e-02, -1.5005e-01, -2.6341e-01,
         -1.2792e-01,  4.7002e-02, -1.1260e-01, -1.9666e-01,  1.3363e-01,
          5.0806e-02, -7.8686e-02, -2.5284e-02, -6.7500e-02, -8.9390e-03,
          1.2266e-01, -2.7628e-02, -1.8756e-01, -1.4294e-01,  2.6419e-04,
         -1.8188e-01, -1.4368e-01, -1.3017e-01,  2.8223e-02, -8.1975e-02,
         -2.4229e-01,  1.0547e-01,  9.4429e-02, -2.9852e-02, -5.3831e-02,
         -1.1238e-01,  6.3411e-02,  5.2182e-02, -1.5774e-01, -2.2411e-01,
         -4.7246e-02, -1.9338e-01, -2.5123e-01,  2.4619e-02, -2.8040e-02,
          4.7881e-02, -5.9730e-02, -3.0478e-01,  1.1754e-01, -4.2383e-02,
         -2.2351e-01, -4.2080e-02, -3.