In [22]:
import torch
from torchvision.models import resnet50, ResNet50_Weights

from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork

import time

# To assist you in designing the feature extractor you may want to print out
# the available nodes for resnet50.
m = resnet50(weights=ResNet50_Weights.DEFAULT)
train_nodes, eval_nodes = get_graph_node_names(resnet50())

# print(eval_nodes)

# # The lists returned, are the names of all the graph nodes (in order of
# # execution) for the input model traced in train mode and in eval mode
# # respectively. You'll find that `train_nodes` and `eval_nodes` are the same
# # for this example. But if the model contains control flow that's dependent
# # on the training mode, they may be different.

# # To specify the nodes you want to extract, you could select the final node
# # that appears in each of the main layers:
return_nodes = {
    # node_name: user-specified key for output dict
    'fc': 'fc'
    # 'layer2.3.relu_2': 'layer2',
    # 'layer3.5.relu_2': 'layer3',
    # 'layer4.2.relu_2': 'layer4',
}

# # But `create_feature_extractor` can also accept truncated node specifications
# # like "layer1", as it will just pick the last node that's a descendent of
# # of the specification. (Tip: be careful with this, especially when a layer
# # has multiple outputs. It's not always guaranteed that the last operation
# # performed is the one that corresponds to the output you desire. You should
# # consult the source code for the input model to confirm.)
# return_nodes = {
#     'layer1': 'layer1',
#     'layer2': 'layer2',
#     'layer3': 'layer3',
#     'layer4': 'layer4',
# }

# Now you can build the feature extractor. This returns a module whose forward
# method returns a dictionary like:
# {
#     'layer1': output of layer 1,
#     'layer2': output of layer 2,
#     'layer3': output of layer 3,
#     'layer4': output of layer 4,
# }

# # Let's put all that together to wrap resnet50 with MaskRCNN

# # MaskRCNN requires a backbone with an attached FPN
# class Resnet50WithFPN(torch.nn.Module):
#     def __init__(self):
#         super(Resnet50WithFPN, self).__init__()
#         # Get a resnet50 backbone
#         m = resnet50()
#         # Extract 4 main layers (note: MaskRCNN needs this particular name
#         # mapping for return nodes)
#         self.body = create_feature_extractor(
#             m, return_nodes={f'layer{k}': str(v)
#                              for v, k in enumerate([1, 2, 3, 4])})
#         # Dry run to get number of channels for FPN
#         inp = torch.randn(2, 3, 224, 224)
#         with torch.no_grad():
#             out = self.body(inp)
#         in_channels_list = [o.shape[1] for o in out.values()]
#         # Build FPN
#         self.out_channels = 256
#         self.fpn = FeaturePyramidNetwork(
#             in_channels_list, out_channels=self.out_channels,
#             extra_blocks=LastLevelMaxPool())

#     def forward(self, x):
#         x = self.body(x)
#         x = self.fpn(x)
#         return x


# # Now we can build our model!
# model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /home/kesharaw/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100.0%


In [23]:
model = create_feature_extractor(m, return_nodes=return_nodes)


img = torch.Tensor(3, 224, 224).normal_().unsqueeze(0) # random image

start_t = time.time_ns()

feature = model(img) # run through the model

end_t = time.time_ns()

print(feature['fc'], (end_t-start_t)/1e6)

tensor([[-4.9084e-02, -1.2750e-01, -1.4604e-01, -9.2053e-02, -9.5742e-02,
         -2.3329e-03, -7.6273e-02, -4.1332e-02, -1.0239e-01,  1.1429e-01,
          1.1852e-01, -2.2967e-02,  1.4970e-02,  1.8402e-01, -9.9965e-02,
         -7.3689e-02, -6.3220e-02, -5.1381e-02,  9.9736e-02,  1.6258e-01,
          1.0990e-01,  1.8339e-01,  1.8290e-03,  2.5504e-01,  2.7385e-02,
         -1.7454e-02,  9.8643e-03,  2.8584e-02, -1.2315e-01, -2.2325e-02,
         -2.0508e-01,  3.6078e-02, -1.7376e-02, -1.0771e-01, -4.0003e-02,
         -3.3293e-01, -1.5244e-01, -1.6180e-01,  2.2616e-01, -1.7417e-01,
         -7.8261e-02, -1.6842e-01, -1.2735e-01, -1.9618e-01, -1.0334e-01,
         -6.5736e-02, -1.4554e-01, -1.0985e-01, -4.2087e-02, -1.5933e-01,
         -7.6268e-02, -1.1801e-01, -2.2458e-02, -9.8565e-02, -3.1209e-02,
         -1.2528e-01, -1.9194e-01, -1.7670e-01,  1.2776e-01, -1.3986e-01,
         -4.3308e-02,  1.2909e-02,  6.8254e-03, -2.4095e-02, -6.9133e-02,
         -1.6542e-01, -5.2536e-03, -9.