
[WIP] Tracing / Scripting #138

Closed
wants to merge 28 commits into from
Conversation

@t-vi commented Nov 10, 2018

With this patch you can get a traced/scripted MaskRCNN model. To facilitate discussion - and maybe also to have a test case for the remaining features wanted in the JIT - I am putting this out in its current state, rather than working in private towards a mergeable patch.
I appreciate your feedback but note that it's not quite ready yet.

We have:

  • demo/trace_model.py produces a torch script model by a mix of tracing/scripting (sketched below),
  • demo/traced_model.cpp uses the model in C++.
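
In compressed form, the flow is roughly this (a sketch, not the full demo/trace_model.py; the stub stands in for the real prediction function):

# Sketch: trace the single-image prediction function and serialize it.
# `single_image_to_top_predictions` is only a stand-in here; the real one
# lives in demo/trace_model.py.
import torch

def single_image_to_top_predictions(image):
    return image * 2  # placeholder for the real pipeline

image = torch.zeros(480, 640, 3)
traced = torch.jit.trace(single_image_to_top_predictions, (image,))
traced.save("end_to_end_model.pt")  # later consumed from C++ via torch::jit::load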

All of this is very raw, and there are the classes of hacks described in issue #27, in particular:

  • calling forward directly all the time to avoid errors about tracing functions whose inputs/outputs are not tensors or tuples; my understanding is that there might be a way to have more structured types in the JIT soonish,
  • some script functions with funny indirections to make the JIT work better (though some might already be unneeded thanks to the awesome work of the PyTorch team),
  • paths with alternative calculations, in particular to avoid inplace operations (see the sketch after this list),
  • a few custom ops, e.g. to avoid untraceable variable-length loops.
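
To illustrate the alternative-calculation point, here is a minimal sketch (not code from the patch itself) of trading an in-place slice assignment for an out-of-place pad; my_paste_mask in the comments below uses the same trick via torch.constant_pad_nd:

# In-place indexing assignment is hostile to tracing/scripting, so instead
# of writing the mask into a slice of a zero canvas, pad it out to the
# full canvas size (pad order: left, right, top, bottom).
import torch

def place_mask(mask, x, y, height, width):
    # in-place version (problematic under the JIT):
    #   canvas = torch.zeros((height, width), dtype=torch.uint8)
    #   canvas[y:y + mask.size(0), x:x + mask.size(1)] = mask
    return torch.constant_pad_nd(
        mask, (x, width - x - mask.size(1), y, height - y - mask.size(0)))

canvas = place_mask(torch.ones(2, 3, dtype=torch.uint8), 1, 2, 8, 8)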

Lots of Todos:

  • there are warnings to be investigated, in particular for loops and boxlist sizes
  • it needs some PyTorch JIT fixes that are in master (all there, thanks!)
  • clean up model, including reverting unneeded changes,
  • clean up cmake, e.g. proper detection of OpenCV,
  • invoke cmake / build the custom ops from setup.py,
  • try a different model/config,
  • maybe script more to be more flexible with sizes,
  • have a leaner pretrained model.

As for my use case: it would also work on Android if PyTorch were there yet.

As this is very much work in progress, here are some quick notes:
- demo/trace_model.py has the current state

Lots of todos:
- there are warnings to be investigated, in particular for loops and
  boxlist sizes
- it needs some PyTorch JIT fixes
- clean up, including reverting unneeded changes
- round off the display code
- do a C++ app
- try a different model/config

@fmassa commented on maskrcnn_benchmark/modeling/poolers.py in 84af4e9 Nov 7, 2018

I think the indentation in here is off

@t-vi (Owner, Author) replied Nov 7, 2018

Thanks, yes!

@facebook-github-bot commented Nov 10, 2018

Thank you for your pull request and welcome to our community. We require contributors to sign our Contributor License Agreement, and we don't seem to have you on file. In order for us to review and merge your code, please sign up at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need the corporate CLA signed.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@facebook-github-bot commented Nov 10, 2018

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!

@Eric-Zhang1990 commented Nov 12, 2018

@t-vi Can you tell me why this error shows up when I run your demo trace_model.py? Thanks. The error is: OSError: /home/eric/Disk100G/githubProject/maskrcnn-benchmark/maskrcnn_benchmark/csrc/custom_ops/libmaskrcnn_benchmark_customops.so: undefined symbol: _ZN5torch3jit8ListType9ofTensorsEv

@t-vi (Author) commented Nov 12, 2018

@Eric-Zhang1990 That symbol (aka torch::jit::ListType::ofTensors()) looks like it should be in libtorch, and libmaskrcnn_benchmark_customops.so should be linked to libtorch if all goes well. Maybe the libtorch you built with isn't the one you use? Given the lightning speed of JIT development, my experience is that having matching versions at build time and at invocation often helps.

@xxradon commented Nov 12, 2018

@t-vi Wonderful work! Can you give a tutorial to reproduce it? I met nearly the same error as @Eric-Zhang1990 did. Thanks.

OSError: /home/shining/Projects/github-projects/pytorch-project/maskrcnn-benchmark/maskrcnn_benchmark/libmaskrcnn_benchmark_customops.so: undefined symbol: _ZN3c108demangleEPKc

@t-vi (Author) commented Nov 12, 2018

@xxradon I think it's premature to expect this to work without bumps, unfortunately. But of course I appreciate that you are trying, and I'll be very happy for suggestions on how to improve the build process.

One key aspect is that you have to make 100% sure that the libtorch you run against exactly matches the one you built with and the one used in Python, and the headers need to be the right ones, too. Try pointing LD_LIBRARY_PATH at torch/lib or so.
In your specific case: For me
objdump -T maskrcnn_benchmark/csrc/custom_ops/build/libmaskrcnn_benchmark_customops.so | c++filt | grep demang
seems to indicate that the lib doesn't require it directly, so apparently whichever libtorch is used doesn't find a matching libc10/libcaffe2...
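
As an aside (my suggestion, not part of the PR): a quick way to see which torch your Python imports, and where its lib directory is, so LD_LIBRARY_PATH can point at that same libtorch:

# Print the version and lib/ directory of the torch this Python imports,
# to compare the build-time and run-time libtorch.
import os
import torch

print(torch.__version__)
print(os.path.join(os.path.dirname(torch.__file__), "lib"))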

@xxradon commented Nov 12, 2018

@t-vi Thanks, your suggestion was right: I can run trace_model.py without mistakes and got the end_to_end_model.pt file. When I tested traced_model.cpp the build was OK, but when I ran the demo there were some errors:
terminate called after throwing an instance of ' c10::Error ' what(): read_bytes == 8 ASSERT FAILED at ../caffe2/serialize/inline_container.h:182, please report a bug to PyTorch. Expected to read 8 bytes but got %llu bytes0 (read64BitIntegerLittleEndian at ../caffe2/serialize/inline_container.h:182) frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x6a (0x7fb099312aaa in /home/shining/Projects/github-projects/pytorch-project/pytorch-build/torch/lib/tmp_install/lib/libc10.so) frame #1: <unknown function> + 0x5b1a86 (0x7fb0aefada86 in /home/shining/Projects/github-projects/pytorch-project/pytorch-build/torch/lib/tmp_install/lib/libtorch.so.1) frame #2: <unknown function> + 0x5b3aef (0x7fb0aefafaef in /home/shining/Projects/github-projects/pytorch-project/pytorch-build/torch/lib/tmp_install/lib/libtorch.so.1) frame #3: <unknown function> + 0x5b490f (0x7fb0aefb090f in /home/shining/Projects/github-projects/pytorch-project/pytorch-build/torch/lib/tmp_install/lib/libtorch.so.1) frame #4: <unknown function> + 0x5ad9cc (0x7fb0aefa99cc in /home/shining/Projects/github-projects/pytorch-project/pytorch-build/torch/lib/tmp_install/lib/libtorch.so.1) frame #5: torch::jit::load(std::istream&) + 0x2f4 (0x7fb0aefad2d4 in /home/shining/Projects/github-projects/pytorch-project/pytorch-build/torch/lib/tmp_install/lib/libtorch.so.1) frame #6: torch::jit::load(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x3a (0x7fb0aefad49a in /home/shining/Projects/github-projects/pytorch-project/pytorch-build/torch/lib/tmp_install/lib/libtorch.so.1) frame #7: main + 0x2d6 (0x409546 in /home/shining/Projects/github-projects/pytorch-project/cpp-pytorch/Release/traced_model) frame #8: __libc_start_main + 0xf0 (0x7fb095d4a830 in /lib/x86_64-linux-gnu/libc.so.6) frame #9: _start + 0x29 (0x409c39 in /home/shining/Projects/github-projects/pytorch-project/cpp-pytorch/Release/traced_model)

I know why this error happens: libmaskrcnn_benchmark_customops.so is not added to the link step. But I added LINK_DIRECTORIES and target_link_libraries in CMakeLists.txt, and that didn't work either.
So can you give me a tutorial of your CMakeLists.txt to run traced_model.cpp?
Thanks a lot!

@t-vi (Author) commented Nov 12, 2018

Oh dear. I forgot to include that. I didn't actually use cmake, just

g++ traced_model.cpp -lopencv_core -lopencv_highgui -lopencv_imgcodecs -I /usr/local/lib/python3.6/dist-packages/torch/lib/include/  -L /usr/local/lib/python3.6/dist-packages/torch/lib/ -ltorch -lcaffe2 -lc10 -lopencv_imgproc -L ../maskrcnn_benchmark/csrc/custom_ops/build/ -lmaskrcnn_benchmark_customops

and then I called it with LD_LIBRARY_PATH=/usr/local/lib/python3.6/dist-packages/torch/lib/:../maskrcnn_benchmark/csrc/custom_ops/build/ ./a.out.

It's less than fancy, sorry about that.

@xxradon commented Nov 12, 2018

[quotes @t-vi's g++ command and LD_LIBRARY_PATH call above]

Thanks for your reply, but I got the same error... Can you try using a CMakeLists.txt instead, or give a more specific tutorial?
Thanks a lot.

@t-vi (Author) commented Nov 12, 2018

Can you actually read the .pt from torch?
(This is the exact error message you get when libtorch cannot open the file/it doesn't exist - I'll put up a PR for a better message.)
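
A quick check along those lines (my suggestion, not from the PR): confirm the file loads in Python first.

# If this fails, the problem is the .pt file (path/serialization), not the
# C++ side. If the model uses the custom ops, load
# libmaskrcnn_benchmark_customops.so via torch.ops.load_library first.
import torch

loaded = torch.jit.load("end_to_end_model.pt")
print(type(loaded))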

@Eric-Zhang1990 commented Nov 13, 2018

@t-vi Now I can trace and save the model, but when I run your traced_model.cpp file, it shows an error about nms:
Starting /home/eric/Disk100G/githubProject/trace_model_cpp_test/build/traced_model...
terminate called after throwing an instance of 'load model ok
torch::jit::script::ErrorReport'
what():
Schema not found for node. File a bug report.
Node: %3556 : Dynamic = maskrcnn_benchmark::nms(%3550, %3554, %3555)

Input types:Dynamic, Dynamic, float
candidates were:
.
Can you tell me how I can fix it?
Thanks a lot.

@t-vi (Author) commented Nov 13, 2018

@Eric-Zhang1990 Did you find and link to the libmaskrcnn_benchmark_customops.so (see the g++ command above - I'm looking to provide cmake)?

@t-vi (Author) commented Nov 13, 2018

With the latest set of changes you should get the custom ops built for you. The C++ demo now has a CMakeLists.txt, but it isn't built for now (for lack of a good place to put the binary).
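
If you need the ops from Python before tracing, here is a hedged sketch (whether your build already loads them for you depends on your setup; the .so path is an example taken from later in this thread):

# Make torch.ops.maskrcnn_benchmark.* (e.g. nms) resolve by loading the
# built custom-op library explicitly.
import torch

torch.ops.load_library(
    "maskrcnn_benchmark/lib/libmaskrcnn_benchmark_customops.so")
print(torch.ops.maskrcnn_benchmark.nms)  # should resolve after loading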

t-vi added 2 commits Nov 13, 2018:

  • With recent JIT improvements, we can use the op directly
  • Thank you, Francisco, for the hint!

Review thread on the ROIAlign forward (code under discussion):

        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio

    def forward(self, input, rois):
        if torch._C._get_tracing_state():
            # we cannot currently trace through the autograd function
            return torch.ops.maskrcnn_benchmark.roi_align_forward(

@fmassa (Contributor) Nov 13, 2018

Question: do we need to register the backward for the op in C++ then, so that we can perform training properly using the same codepath?

@t-vi (Author) Nov 13, 2018
The current obstacle here is that we cannot trace through the Python autograd function.
I'm not aware of a way to register the derivative so we could go through the op directly during training as well.

@fmassa (Contributor) Nov 13, 2018

Thanks! @goldsborough do you know if we can register derivatives for custom ops?

@t-vi (Author) Nov 14, 2018

Peter responded elsewhere with "not quite yet", but it is a to-do on his list.

@nicolasCruzW21 commented May 30, 2019

@t-vi And for anyone wanting to use this on GPU: I managed to export an end-to-end model. Here are the instructions.

You will need:
PyTorch-nightly 1.0.0
PyTorch 1.0.1 (it won't work with 1.0.0 or with 1.1)
OpenCV 3.2

To install, follow the instructions in INSTALL.md but do not build the project. Instead, go to the maskrcnn_benchmark folder and make a new folder called "lib". Then:

cd $INSTALL_DIR
cd maskrcnn-benchmark
python setup.py build develop
cd maskrcnn_benchmark
cd lib 
ln -s ../../build/lib.linux-x86_64-3.6/maskrcnn_benchmark/lib/libmaskrcnn_benchmark_customops.so libmaskrcnn_benchmark_customops.so

Now go to the demo folder and put two images, test1.jpg and test2.jpg, in it.

Create a new file (or replace trace_model.py) with this code:

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from __future__ import division
import os

import numpy
from io import BytesIO
from matplotlib import pyplot

import requests
import torch
from torch.jit import ScriptModule, script_method, trace, Tensor
from PIL import Image
from maskrcnn_benchmark.config import cfg
from predictor import COCODemo
from maskrcnn_benchmark.structures.image_list import ImageList

if __name__ == "__main__":
    # load config from file and command-line arguments

    project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    cfg.merge_from_file(
        os.path.join(project_dir,
                     "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml"))
    #cfg.merge_from_list(["MODEL.DEVICE", "cpu"])
    cfg.freeze()

    # prepare object that handles inference plus adds predictions on top of image
    coco_demo = COCODemo(
        cfg,
        confidence_threshold=0.7,
        show_mask_heatmaps=False,
        masks_per_dim=2,
        min_image_size=480,
    )


def single_image_to_top_predictions(image):
    image = image.float() / 255.0
    image = image.permute(2, 0, 1)
    # we are loading images with OpenCV, so we don't need to convert them
    # to BGR, they are already! So all we need to do is to normalize
    # by 255 if we want to convert to BGR255 format, or flip the channels
    # if we want it to be in RGB in [0-1] range.
    if cfg.INPUT.TO_BGR255:
        image = image * 255
    else:
        image = image[[2, 1, 0]]

    # we absolutely want fixed size (int) here (or we run into a tracing error (or bug?)
    # or we might later decide to make things work with variable size...
    image = image - torch.tensor(cfg.INPUT.PIXEL_MEAN)[:, None, None]
    # should also do variance...
    image_list = ImageList(image.unsqueeze(0), [(int(image.size(-2)), int(image.size(-1)))])
    image_list = image_list.to(coco_demo.device)
    results = coco_demo.model(image_list)
    results = [o.to(coco_demo.cpu_device) for o in results]
    result=results[0]
    scores = result.get_field("scores")
    keep = (scores >= coco_demo.confidence_threshold)
    result = (result.bbox[keep],
              result.get_field("labels")[keep],
              result.get_field("mask")[keep],
              scores[keep])
    return result


@torch.jit.script
def my_paste_mask(mask, bbox, height, width, threshold=0.5, padding=1, contour=True, rectangle=False):
    # type: (Tensor, Tensor, int, int, float, int, bool, bool) -> Tensor
    padded_mask = torch.constant_pad_nd(mask, (padding, padding, padding, padding))
    scale = 1.0 + 2.0 * float(padding) / float(mask.size(-1))
    center_x = (bbox[2] + bbox[0]) * 0.5
    center_y = (bbox[3] + bbox[1]) * 0.5
    w_2 = (bbox[2] - bbox[0]) * 0.5 * scale
    h_2 = (bbox[3] - bbox[1]) * 0.5 * scale  # should have two scales?
    bbox_scaled = torch.stack([center_x - w_2, center_y - h_2,
                               center_x + w_2, center_y + h_2], 0)

    TO_REMOVE = 1
    w = (bbox_scaled[2] - bbox_scaled[0] + TO_REMOVE).clamp(min=1).long()
    h = (bbox_scaled[3] - bbox_scaled[1] + TO_REMOVE).clamp(min=1).long()

    scaled_mask = torch.ops.maskrcnn_benchmark.upsample_bilinear(padded_mask.float(), h, w)

    x0 = bbox_scaled[0].long()
    y0 = bbox_scaled[1].long()
    x = x0.clamp(min=0)
    y = y0.clamp(min=0)
    leftcrop = x - x0
    topcrop = y - y0
    w = torch.min(w - leftcrop, width - x)
    h = torch.min(h - topcrop, height - y)

    # mask = torch.zeros((height, width), dtype=torch.uint8)
    # mask[y:y + h, x:x + w] = (scaled_mask[topcrop:topcrop + h,  leftcrop:leftcrop + w] > threshold)
    mask = torch.constant_pad_nd((scaled_mask[topcrop:topcrop + h, leftcrop:leftcrop + w] > threshold),
                                 (int(x), int(width - x - w), int(y), int(height - y - h)))   # int for the script compiler

    if contour:
        mask = mask.float()
        # poor person's contour finding by comparing to smoothed
        mask = (mask - torch.nn.functional.conv2d(mask.unsqueeze(0).unsqueeze(0),
                                                  torch.full((1, 1, 3, 3), 1.0 / 9.0), padding=1)[0, 0]).abs() > 0.001
    if rectangle:
        x = torch.arange(width, dtype=torch.long).unsqueeze(0)
        y = torch.arange(height, dtype=torch.long).unsqueeze(1)
        r = bbox.long()
        # work around script not liking bitwise ops
        rectangle_mask = ((((x == r[0]) + (x == r[2])) * (y >= r[1]) * (y <= r[3]))
                          + (((y == r[1]) + (y == r[3])) * (x >= r[0]) * (x <= r[2])))
        mask = (mask + rectangle_mask).clamp(max=1)
    return mask


@torch.jit.script
def add_annotations(image, labels, scores, bboxes, class_names=','.join(coco_demo.CATEGORIES), color=torch.tensor([255, 255, 255], dtype=torch.long)):
    # type: (Tensor, Tensor, Tensor, Tensor, str, Tensor) -> Tensor
    result_image = torch.ops.maskrcnn_benchmark.add_annotations(image, labels, scores, bboxes, class_names, color)
    return result_image


@torch.jit.script
def combine_masks_tuple(input_model):
    # type: (Tuple[Tensor, Tensor, Tensor, Tensor, Tensor,Tensor]) -> Tensor
    image, bboxes, labels, masks, scores,palette=input_model
    threshold=0.5
    padding=1
    contour=True
    rectangle=False
    
    
    height = image.size(0)
    width = image.size(1)
    image_with_mask = image.clone()
    for i in range(masks.size(0)):
        color = ((palette * labels[i]) % 255).to(torch.uint8)
        one_mask = my_paste_mask(masks[i, 0], bboxes[i], height, width, threshold, padding, contour, rectangle)
        image_with_mask = torch.where(one_mask.unsqueeze(-1), color.unsqueeze(0).unsqueeze(0), image_with_mask)
    image_with_mask = add_annotations(image_with_mask, labels, scores, bboxes)
    return image_with_mask

def process_image_with_traced_model(image):
    boxes, labels, masks, scores = traced_model(image)
    # build the input tuple expected by combine_masks_tuple (same palette as below)
    palette = torch.tensor([3, 32767, 2097151])
    result_image = combine_masks_tuple((image, boxes, labels, masks, scores, palette))
    return result_image

def fetch_image(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")

if __name__ == "__main__":
    pil_image =Image.open("test1.jpg").convert("RGB")
    pil_image = pil_image.resize((640, 480), Image.BILINEAR)

    # convert to BGR format
    image = torch.from_numpy(numpy.array(pil_image)[:, :, [2, 1, 0]])
    original_image = image

    if coco_demo.cfg.DATALOADER.SIZE_DIVISIBILITY:
        assert (image.size(0) % coco_demo.cfg.DATALOADER.SIZE_DIVISIBILITY == 0
                and image.size(1) % coco_demo.cfg.DATALOADER.SIZE_DIVISIBILITY == 0)

    for p in coco_demo.model.parameters():
        p.requires_grad_(False)



    traced_model = torch.jit.trace(single_image_to_top_predictions, (image,),optimize=False)
    traced_model.save('traced.pt')
    #@torch.jit.script
    def end_to_end_model(image):
        boxes, labels, masks, scores = single_image_to_top_predictions(image)
        palette=torch.tensor([3, 32767, 2097151])
        input_model=image, boxes, labels, masks, scores, palette
        result_image = combine_masks_tuple(input_model)
        return result_image
    
    end_to_end_model_traced = torch.jit.trace(end_to_end_model, (image,),optimize=False)
    end_to_end_model_traced.save('end_to_end_model.pt')

    image3 =Image.open("test2.jpg").convert("RGB")
    image3 = image3.resize((640, 480), Image.BILINEAR)
    image3 = torch.from_numpy(numpy.array(image3)[:, :, [2, 1, 0]])
    loaded = torch.jit.load("end_to_end_model.pt")
    result_image3 = loaded(image3)
    pyplot.imshow(result_image3[:, :, [2, 1, 0]])
    pyplot.show()

This will trace the model with CUDA enabled.
If you want to trace the model on the CPU, just uncomment the line:
#cfg.merge_from_list(["MODEL.DEVICE", "cpu"])

Two outputs are produced. The first one is just the tracing of the model.
The second one is the tracing including the code to plot an image.

If you want to output an image in OpenCV to plot the results, use the example already provided in the cpp folder. The speed is 0.28 s on average on a 2080 Ti, much faster than the CPU version, which runs at about 2 s.

If you only want the raw model, you can access the masks, scores, bboxes and labels by loading the "traced.pt" model and using:

auto result = moduleTorch->forward(inputs);
auto tuple = (result.toTuple())->elements();
auto bboxes = tuple[0].toTensor();
auto labels = tuple[1].toTensor();
auto mask = tuple[2].toTensor();
auto scores = tuple[3].toTensor();

Hope this helps someone!
And thanks for all the help too.

@sukisleep commented May 30, 2019

[quotes @nicolasCruzW21's GPU export instructions above in full]

Thank you for sharing!

If possible, could you also share your C++ inference code with us? Thanks again for the excellent write-up.

@nicolasCruzW21 commented May 30, 2019

@sukisleep I actually work in a C++ framework, so there will be some differences when you implement this in your own code. Here is what I did:

First, download the appropriate version of the C++ PyTorch. Since the Python side of things only works with PyTorch 1.0.1, and since I'm using CUDA 10, that would be "https://download.pytorch.org/libtorch/cu100/libtorch-shared-with-deps-1.0.1.zip".

Once you have that, you need to choose your version of C++. My framework is C++14, so it won't work with the downloaded package out of the box. To fix this, go to libtorch/share/cmake, and in each of those files, if you find something like:
CXX_STANDARD 11

or

-std=c++11

change the 11 to a 14 to get it to work.
The package also comes with the old ABI, so to link the PyTorch library without any undefined symbols I had to compile both my code and OpenCV with the old ABI.
To do this, go to your OpenCV folder (remember, OpenCV 3.2) and at the beginning of its CMake file add this line:
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
Now compile and install it.
Then go to the CMake file of your own code and add the same line to compile with the old ABI:
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)

First you will need CUDA and OpenCV of course, so you need these lines:

find_package(OpenCV 3 REQUIRED)
find_package(CUDA 10.0 REQUIRED)

You also need to import libtorch, so use this line.
find_package(Torch REQUIRED PATHS <path to libtorch>/libtorch)
(you need to replace the path in accordance to where your libtorch package is located).

Now let's include the relevant directories of our libs:

include_directories(${OpenCV_INCLUDE} ${TORCH_INCLUDE_DIRS} ${Any other stuff you want to add})

And finally, let's link the libraries.
target_link_libraries( YourProject ${OpenCV_LIBS} ${TORCH_LIBRARIES} ${Any other stuff you want to add})

Now you can compile your code with pytorch libraries.
Suppose that you want to process images from a camera using mask-rcnn. You want the operation to be fast (so you will need the cuda version of the end_to_end model) and you want to load the model only once.

So in your File.h you want to import the following:

//pytorch
#include <torch/script.h>
#include <torch/csrc/jit/import.h>
#include <torch/torch.h>
#include <torch/csrc/api/include/torch/jit.h>
//pytorchEnd
#include <dlfcn.h>

In File.h you want to define some variables (note: the module member may not share the class's name, so it is called model here):

class maskRCNN : public maskRCNNBase
{
public:
    maskRCNN();
    std::shared_ptr<torch::jit::script::Module> model;

private:

    //some other functions that you may want

    torch::DeviceType device_type;
    at::Tensor output;
    float* bufferImgIn;
    cv::Mat img; //mask-rcnn
    void* custom_op_lib;
};

Now for the cpp, File.cpp.
First you want to load the model as well as your custom ops, and allocate space for your input; all of this is done in the constructor.

maskRCNN::maskRCNN()
{

    if (torch::cuda::is_available())
    {
        std::string libtorchFolder = "/<path to where your .pt model is stored>";
        const std::string traceFile = libtorchFolder + "/end_to_end_model.pt";
        device_type = torch::kCUDA;
        torch::Device device(device_type);


        const char* customOpsFolder = "<path to your lib folder>/libmaskrcnn_benchmark_customops.so";
        custom_op_lib = dlopen(customOpsFolder, RTLD_NOW | RTLD_GLOBAL); // load the custom ops

        if (custom_op_lib == NULL) { // did they load?
            std::cerr << "could not open custom op library: " << dlerror() << std::endl;
            ASSERT(false);
        }
        model = torch::jit::load(traceFile); // load the traced model
        model->to(at::kCUDA);                // move it to CUDA

        img = cv::Mat(480, 640, CV_8UC3);

    }
    else
    {
        ASSERT(false);
    }

    ASSERT(model != nullptr);
}

And once that is done, the rest is very straightforward.
Suppose your input is an OpenCV image from a camera called theImage. Then do this:

cv::resize(theImage, img, img.size(), 0, 0, cv::INTER_AREA);
auto input_ = torch::tensor(at::ArrayRef<uint8_t>(img.data, img.rows * img.cols * 3))
                  .view({img.rows, img.cols, 3});
std::vector<torch::jit::IValue> inputs;
inputs.push_back(input_);
auto res = model->forward(inputs).toTensor();
inputs.clear();
cv::Mat cv_res(res.size(0), res.size(1), CV_8UC3, (void*) res.data<uint8_t>());
cv::namedWindow("Detected", cv::WINDOW_AUTOSIZE);
cv::imshow("Detected", cv_res);

And that is all, you have your output in c++.

@deeponcology commented Jun 7, 2019

And that is all, you have your output in c++.

@nicolasCruzW21 Thank you so much for the very detailed description.
Can you kindly share a minimal workable example on Git?

Thanks.

@jinfagang commented Jun 8, 2019

@nicolasCruzW21 Glad you got it worked out. Can you elaborate on how to get the libmaskrcnn_benchmark_customops.so? Do you have a full project with a CMakeLists.txt and the C++ inference code?

@imranparuk commented Jun 19, 2019

Any update on the JIT tracing of BoxList?

@ys0823 commented Jul 7, 2019

Thanks for your work! I am new to this. When I install the environment according to INSTALL.md and compile setup.py for the last step with python setup.py build develop, there is always an error like this:

copying build/lib.linux-x86_64-3.6/maskrcnn_benchmark/_C.cpython-36m-x86_64-linux-gnu.so -> maskrcnn_benchmark
copying build/lib.linux-x86_64-3.6/maskrcnn_benchmark/lib/custom_ops.cpython-36m-x86_64-linux-gnu.so -> maskrcnn_benchmark/lib
error: could not create 'maskrcnn_benchmark/lib/custom_ops.cpython-36m-x86_64-linux-gnu.so': No such file or directory

My setup is:
pytorch 1.0.0
opencv 3.2
gcc 7.3
This seems strange to me; any ideas?

@ys0823 commented Jul 15, 2019

Thanks to the discussion above, I can get results with libtorch, but I have a problem when I test an image that contains no objects (my task is object detection). Is this a code bug or my fault?
The error is:
terminate called after throwing an instance of 'std::runtime_error' what(): cuDNN error: CUDNN_STATUS_BAD_PARAM (set at /pytorch/aten/src/ATen/cudnn/Descriptors.h:127) frame #0: std::function<std::string ()>::operator()() const + 0x11 (0x7fd7585d4021 in /data/tanglin/libtorch/lib/libc10.so) frame #1: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x2a (0x7fd7585d38ea in /data/tanglin/libtorch/lib/libc10.so) frame #2: at::native::TensorDescriptor::set(cudnnDataType_t, c10::ArrayRef<long>, c10::ArrayRef<long>, unsigned long) + 0x413 (0x7fd759bed883 in /data/tanglin/libtorch/lib/libcaffe2_gpu.so) frame #3: at::native::TensorDescriptor::set(at::Tensor const&, unsigned long) + 0x272 (0x7fd759bee242 in /data/tanglin/libtorch/lib/libcaffe2_gpu.so) frame #4: at::native::raw_cudnn_convolution_forward_out(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool) + 0x104 (0x7fd759a14774 in /data/tanglin/libtorch/lib/libcaffe2_gpu.so) frame #5: at::native::cudnn_convolution_forward(char const*, at::TensorArg const&, at::TensorArg const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool) + 0x45e (0x7fd759a1516e in /data/tanglin/libtorch/lib/libcaffe2_gpu.so) frame #6: at::native::cudnn_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool) + 0x187 (0x7fd759a156a7 in /data/tanglin/libtorch/lib/libcaffe2_gpu.so) frame #7: at::CUDAFloatType::cudnn_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool) const + 0xb2 (0x7fd759aefb62 in /data/tanglin/libtorch/lib/libcaffe2_gpu.so) frame #8: torch::autograd::VariableType::cudnn_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool) const + 0x2d5 (0x7fd79694ae85 in /data/tanglin/libtorch/lib/libtorch.so.1) frame #9: at::native::_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool) + 0x16e4 (0x7fd787211fa4 in /data/tanglin/libtorch/lib/libcaffe2.so) frame #10: at::TypeDefault::_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool) const + 0xce (0x7fd78750f03e in /data/tanglin/libtorch/lib/libcaffe2.so) frame #11: torch::autograd::VariableType::_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool) const + 0x26a (0x7fd7968f118a in /data/tanglin/libtorch/lib/libtorch.so.1) frame #12: <unknown function> + 0x58721b (0x7fd796a8721b in /data/tanglin/libtorch/lib/libtorch.so.1) frame #13: <unknown function> + 0x672286 (0x7fd796b72286 in /data/tanglin/libtorch/lib/libtorch.so.1) frame #14: torch::jit::InterpreterState::run(std::vector<c10::IValue, std::allocator<c10::IValue> >&) + 0x22 (0x7fd796b6d842 in /data/tanglin/libtorch/lib/libtorch.so.1) frame #15: <unknown function> + 0x65c6ac (0x7fd796b5c6ac in /data/tanglin/libtorch/lib/libtorch.so.1) frame #16: ./Peoplecount() [0x404efc] frame #17: __libc_start_main + 0xf0 (0x7fd753125830 in /lib/x86_64-linux-gnu/libc.so.6) frame #18: 
./Peoplecount() [0x406219] : operation failed in interpreter: _407 = [_401, _403, _405, torch.to(_406, 6, False, False)] first_result = torch.select(_407, 0) dtype = ops.prim.dtype(first_result) device = ops.prim.device(first_result) _408 = [torch.size(levels, 0), torch.size(first_result, 1), torch.size(first_result, 2), torch.size(first_result, 3)] res = torch.zeros(_408, dtype=dtype, layout=0, device=device) for l in range(torch.len(_407)): _409 = torch.view(torch.eq(levels, l), [-1, 1, 1, 1]) _410 = torch.masked_scatter_(res, _409, torch.select(_407, l)) input_79 = torch._convolution(res, CONSTANTS.c218, CONSTANTS.c17, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True) ~~~~~~~~~~~~~~~~~~ <--- HERE input_80 = torch.relu(input_79) input_81 = torch._convolution(input_80, CONSTANTS.c219, CONSTANTS.c17, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True) input_82 = torch.relu(input_81) input_83 = torch._convolution(input_82, CONSTANTS.c220, CONSTANTS.c17, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True) input_84 = torch.relu(input_83) input_85 = torch._convolution(input_84, CONSTANTS.c221, CONSTANTS.c17, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True) input_86 = torch.relu(input_85) input_87 = torch._convolution(input_86, CONSTANTS.c222, CONSTANTS.c17, [2, 2], [0, 0], [1, 1], True, [0, 0], 1, False, False, True) input = torch.relu(input_87) Aborted (core dumped)
Could somebody tell me the reason?

@ys0823 commented Jul 15, 2019

[quotes their earlier comment about the "could not create 'maskrcnn_benchmark/lib/custom_ops...'" build error]

I solved it by making a new directory named lib at ../name/maskrcnn-benchmark/maskrcnn_benchmark/lib.

@zimenglan-sysu-512 (Contributor) commented Jul 25, 2019

Hi @t-vi,
when building the csrc I get the error /usr/bin/ld: cannot find -lopencv_imgcodecs. How do you fix it?
Thanks

@zimenglan-sysu-512 (Contributor) commented Jul 26, 2019

I also meet this problem: OSError: maskrcnn_benchmark/lib/libmaskrcnn_benchmark_customops.so: undefined symbol: _ZN2cv7putTextERNS_3MatERKSsNS_6Point_IiEEidNS_7Scalar_IdEEiib
Hope someone can help me out.
Thanks

@ys0823 commented Jul 27, 2019

and also meet the problem: OSError: maskrcnn_benchmark/lib/libmaskrcnn_benchmark_customops.so: undefined symbol: _ZN2cv7putTextERNS_3MatERKSsNS_6Point_IiEEidNS_7Scalar_IdEEiib
hope someone can help me out.
thanks

I met the same problem, for two reasons: first, my gcc was lower than 4.9; second, I hadn't created the project directory correctly. The layout should be:
../your project
---apex
---cocoapi
---maskrcnn-benchmark
Hope that helps you.

@zimenglan-sysu-512 (Contributor) commented Jul 30, 2019

It seems that BoxList does not support tracing. @imranparuk

@zimenglan-sysu-512 (Contributor) commented Aug 5, 2019

Hi @nicolasCruzW21,
I find that GPU mode runs into this problem:

vector::_M_range_check: __n (which is 18446744073709551615) >= this->size() (which is 2)

To solve it, you can change dim=-1 to dim=1 in the bounding_box.py file. For more detail, see this link.
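
For illustration, a minimal sketch of that change (the exact line in bounding_box.py may differ; the tensor name here is an assumption):

# Splitting the (N, 4) box tensor along an explicit positive dimension;
# dim=-1 triggered the _M_range_check failure above on GPU.
import torch

boxes = torch.zeros(5, 4)
# was: boxes.split(1, dim=-1)
xmin, ymin, xmax, ymax = boxes.split(1, dim=1)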

Hope it can help you.

@ys0823 commented Aug 11, 2019

@t-vi Thanks for your reply~ From the discussion above, I thought it would work for different images, which confused me. And actually I met the shape error on both CPU and GPU.

@zhuqiang00099 Hi, can you give me some details about how you got the custom_ops.dll? I tried to build it separately, making sure to include and compile the *.cu files, but when I use the DLL in trace_model there is an error saying it wasn't compiled with CUDA.
Just to confirm: I build it with CMake and VS15 and changed the CMakeLists.txt to
CUDA_ADD_LIBRARY(maskrcnn_benchmark_customops SHARED custom_ops.cpp ../cpu/nms_cpu.cpp ../cpu/ROIAlign_cpu.cpp ../cuda/nms.cu ../cuda/ROIAlign_cuda.cu)

Thanks!

@jojojo29 I have the same problem with nms on CUDA; for some reason I must build it with CMake. Have you already got nms working on CUDA? Could you share the way to build nms for both CPU and GPU with CMake?

@WonderAndMaps commented Aug 13, 2019

[quotes @xxradon's earlier comment about the read_bytes == 8 ASSERT FAILED error and the missing CMakeLists.txt]

@xxradon Hi, I just came across this. Can you elaborate a little more on how you fixed the libtorch problem? Many thanks.

@t-vi (Author) commented Oct 10, 2019

So given that torchvision is the place to go for things like those explored here, I'm closing this PR.

@t-vi closed this Oct 10, 2019
@fmassa (Contributor) commented Oct 11, 2019

Thanks a lot for all your work and help Thomas!

@engineer1109 commented Feb 18, 2020

@t-vi, what is your PyTorch version?
I got this error:
RuntimeError: Tried to trace <torch.maskrcnn_benchmark.modeling.roi_heads.box_head.inference.PostProcessor object at 0x5c35f20> but it is not part of the active trace. Modules that are called during a trace must be registered as submodules of the thing being traced.

Torch 1.3.1

@kewin1807 commented Feb 28, 2020

How did you get Mask R-CNN running on Android? It is amazing. Can you share the key steps or a repo? Thanks

@nicolasCruzW21 commented Jun 4, 2020

For anyone still interested in this, I made a repo compatible with PyTorch 1.5 to export the CUDA-enabled versions of the models rather than the CPU versions. It also fixes the issue of crashing when no objects are in the image. The code is ugly since this is a port of the original 1.0.1 version. Thanks to @t-vi again for all the help and for developing the vast majority of the code. You can find the code here:

https://github.com/nicolasCruzW21/maskrcnn-Tracing.git

I hope it's useful

@bmabir17 commented Jun 29, 2020

@t-vi Were you able to run the Mask R-CNN conversion as hinted at on https://lernapparat.de/pytorch-android/ ? If so, could you kindly give us a hint on how to do this?

@t-vi (Author) commented Jun 29, 2020

Yes, that was fun in 2018. I'd recommend using the TorchVision-provided code or @nicolasCruzW21's branch over anything in this PR.

@adibaig1 commented Jan 26, 2021

[quotes @nicolasCruzW21's comment above about the PyTorch 1.5 tracing repo]

Great work!
I am interested in getting the scripted model from maskrcnn-benchmark. I will give it a shot and update soon.
