
Support of Detectron2 models #24

Closed
justlike-prog opened this issue Mar 22, 2022 · 1 comment
Labels
example Example of using the nebullvm library question Further information is requested

Comments

@justlike-prog
For now Detectron2 models don't work out of the box because of the different input format.

It would be nice to have an example on how to fix that.

Thanks!

@diegofiori diegofiori added the question Further information is requested label Mar 23, 2022
@diegofiori diegofiori self-assigned this Mar 23, 2022
@diegofiori
Collaborator

Hello @justlike-prog!

I took a look at the Detectron2 source code and found an easy workaround for optimizing it. In Detectron2, most of the computation is done by the ResNet-based backbone, so we can optimize just the backbone and already get good results.

The problem is that the backbone outputs a dictionary, which we need to map to a tuple to be consistent with the nebullvm API. I suggest defining two "wrapper" classes (one for the non-optimized model and one for the optimized one) and then using them to run the optimized Detectron2 model.

Let's define the classes as follows:

import torch


class BaseModelWrapper(torch.nn.Module):
    def __init__(self, core_model, output_dict):
        super().__init__()
        self.core_model = core_model
        self.output_names = list(output_dict.keys())

    def forward(self, *args, **kwargs):
        res = self.core_model(*args, **kwargs)
        # Map the backbone's dict output to a tuple in a fixed key order.
        return tuple(res[key] for key in self.output_names)


class OptimizedWrapper(torch.nn.Module):
    def __init__(self, optimized_model, output_keys):
        super().__init__()
        self.optimized_model = optimized_model
        self.output_keys = output_keys

    def forward(self, *args):
        res = self.optimized_model(*args)
        # Map the tuple output back to the dict format Detectron2 expects.
        return dict(zip(self.output_keys, res))
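To see what the two wrappers do, here is a quick sanity check of the dict → tuple → dict round-trip, sketched with a plain function standing in for the backbone (no Detectron2 or torch needed; `toy_backbone` and its keys are made up for illustration):

```python
# Toy stand-in for a backbone that returns a dict of feature maps.
def toy_backbone(x):
    return {"p2": x + 1, "p3": x * 2}

# Mimic BaseModelWrapper: record the key order once, then emit tuples.
output_names = list(toy_backbone(0).keys())  # ["p2", "p3"]

def wrapped(x):
    res = toy_backbone(x)
    return tuple(res[key] for key in output_names)

# Mimic OptimizedWrapper: zip the tuple back into a dict.
def rewrapped(x):
    return dict(zip(output_names, wrapped(x)))

print(wrapped(3))    # tuple in the recorded key order
print(rewrapped(3))  # dict restored with the original keys
```

The key point is that the key order captured at construction time is what keeps the tuple and the dict consistent across the round-trip.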

Then you can simply run the following code:

import copy

import torch
from detectron2 import model_zoo
from nebullvm import optimize_torch_model

config_path = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
model = model_zoo.get(config_path, trained=True)
model.eval()
model_backbone = copy.deepcopy(model.backbone)
res = model_backbone(torch.randn(1, 3, 256, 256))  # needed for getting the output keys
backbone_wrapper = BaseModelWrapper(model_backbone, res)
optimized_model = optimize_torch_model(
    backbone_wrapper, batch_size=1, input_sizes=[(3, 256, 256)], save_dir="./"
)
optimized_backbone = OptimizedWrapper(optimized_model, backbone_wrapper.output_names)
# Finally, replace the old backbone with the optimized one.
model.backbone = optimized_backbone

If you need dynamic input shapes, you'll have to pass a few more arguments to optimize_torch_model. Please take a look at issue #26 for further info.
