Export Detic to ONNX with custom vocabulary #113

Open
gigasurgeon opened this issue Nov 16, 2023 · 3 comments


gigasurgeon commented Nov 16, 2023

I wanted to share the method to export the Detic model to ONNX format with a custom vocabulary.

Step 1) First, comment out this line in custom_rcnn.py: box_features = _ScaleGradient.apply(box_features, 1.0 / self.num_cascade_stages)
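For context, _ScaleGradient is an identity in the forward pass; it only rescales gradients during training, so dropping it does not change inference outputs. Roughly, the detectron2 cascade head defines it like this (paraphrased sketch; your version may differ):

from torch.autograd import Function

class _ScaleGradient(Function):
    @staticmethod
    def forward(ctx, input, scale):
        ctx.scale = scale
        return input  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None  # only the gradient is rescaled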

Step 2) Also, according to this comment #107 (comment), you have to comment out the nms_and_topK line in centernet while exporting the model:

boxlists = self.nms_and_topK(boxlists, nms=not self.not_nms)
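With nms_and_topK skipped inside the graph, the exported model can return overlapping candidates. My inference script further down only deduplicates boxes with identical coordinates, but you can also run proper class-aware NMS on the ONNX outputs. A minimal sketch, assuming pred_boxes, scores and pred_classes are the numpy arrays returned by the ONNX session (as in that script):

import torch
from torchvision.ops import batched_nms

# class-aware NMS applied outside the exported graph
keep = batched_nms(
    torch.from_numpy(pred_boxes).float(),
    torch.from_numpy(scores).float(),
    torch.from_numpy(pred_classes).long(),
    iou_threshold=0.5,
).numpy()
pred_boxes, scores, pred_classes = pred_boxes[keep], scores[keep], pred_classes[keep]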

Step 3) Now on to the main part. You need to modify this file -> Detic/detectron2/tools/deploy/export_model.py

This is the final script I had:

#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import os
from typing import Dict, List, Tuple
import torch
from torch import Tensor, nn

import sys
sys.path.insert(0, '/vmdata/amitsingh/workspace/Detic')
sys.path.insert(0, '/vmdata/amitsingh/workspace/Detic/third_party/CenterNet2')

import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader, detection_utils
from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format
from detectron2.export import (
    STABLE_ONNX_OPSET_VERSION,
    TracingAdapter,
    dump_torchscript_IR,
    scripting_with_instances,
)
from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.projects.point_rend import add_pointrend_config
from detectron2.structures import Boxes
from detectron2.utils.env import TORCH_VERSION
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger
from centernet.config import add_centernet_config
from detic.config import add_detic_config


def setup_cfg(args):
    cfg = get_cfg()
    # cuda context is initialized before creating dataloader, so we don't fork anymore
    cfg.DATALOADER.NUM_WORKERS = 0
    add_pointrend_config(cfg)
    add_centernet_config(cfg)
    add_detic_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg


def export_caffe2_tracing(cfg, torch_model, inputs):
    from detectron2.export import Caffe2Tracer

    tracer = Caffe2Tracer(cfg, torch_model, inputs)
    if args.format == "caffe2":
        caffe2_model = tracer.export_caffe2()
        caffe2_model.save_protobuf(args.output)
        # draw the caffe2 graph
        caffe2_model.save_graph(os.path.join(args.output, "model.svg"), inputs=inputs)
        return caffe2_model
    elif args.format == "onnx":
        import onnx

        onnx_model = tracer.export_onnx()
        onnx.save(onnx_model, os.path.join(args.output, "model.onnx"))
    elif args.format == "torchscript":
        ts_model = tracer.export_torchscript()
        with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
            torch.jit.save(ts_model, f)
        dump_torchscript_IR(ts_model, args.output)


# experimental. API not yet final
def export_scripting(torch_model):
    assert TORCH_VERSION >= (1, 8)
    fields = {
        "proposal_boxes": Boxes,
        "objectness_logits": Tensor,
        "pred_boxes": Boxes,
        "scores": Tensor,
        "pred_classes": Tensor,
        "pred_masks": Tensor,
        "pred_keypoints": torch.Tensor,
        "pred_keypoint_heatmaps": torch.Tensor,
    }
    assert args.format == "torchscript", "Scripting only supports torchscript format."

    class ScriptableAdapterBase(nn.Module):
        # Use this adapter to workaround https://github.com/pytorch/pytorch/issues/46944
        # by not retuning instances but dicts. Otherwise the exported model is not deployable
        def __init__(self):
            super().__init__()
            self.model = torch_model
            self.eval()

    if isinstance(torch_model, GeneralizedRCNN):

        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                instances = self.model.inference(inputs, do_postprocess=False)
                return [i.get_fields() for i in instances]

    else:

        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                instances = self.model(inputs)
                return [i.get_fields() for i in instances]

    ts_model = scripting_with_instances(ScriptableAdapter(), fields)
    with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
        torch.jit.save(ts_model, f)
    dump_torchscript_IR(ts_model, args.output)
    # TODO inference in Python now missing postprocessing glue code
    return None


# experimental. API not yet final
def export_tracing(torch_model, inputs):
    assert TORCH_VERSION >= (1, 8)
    image = inputs[0]["image"]
    inputs = [{"image": image}]  # remove other unused keys

    if isinstance(torch_model, GeneralizedRCNN):

        def inference(model, inputs):
            # use do_postprocess=False so it returns ROI mask
            inst = model.inference(inputs, do_postprocess=False)[0]
            return [{"instances": inst}]

    else:
        inference = None  # assume that we just call the model directly

    traceable_model = TracingAdapter(torch_model, inputs, inference)

    if args.format == "torchscript":
        ts_model = torch.jit.trace(traceable_model, (image,))
        with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
            torch.jit.save(ts_model, f)
        dump_torchscript_IR(ts_model, args.output)
    elif args.format == "onnx":
        with PathManager.open(os.path.join(args.output, "model.onnx"), "wb") as f:
            torch.onnx.export(traceable_model, (image,), f, opset_version=STABLE_ONNX_OPSET_VERSION)
    logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
    logger.info("Outputs schema: " + str(traceable_model.outputs_schema))

    if args.format != "torchscript":
        return None
    if not isinstance(torch_model, (GeneralizedRCNN, RetinaNet)):
        return None

    def eval_wrapper(inputs):
        """
        The exported model does not contain the final resize step, which is typically
        unused in deployment but needed for evaluation. We add it manually here.
        """
        input = inputs[0]
        instances = traceable_model.outputs_schema(ts_model(input["image"]))[0]["instances"]
        postprocessed = detector_postprocess(instances, input["height"], input["width"])
        return [{"instances": postprocessed}]

    return eval_wrapper


def get_sample_inputs(args):

    if args.sample_image is None:
        # get a first batch from dataset
        data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        first_batch = next(iter(data_loader))
        return first_batch
    else:
        # get a sample data
        original_image = detection_utils.read_image(args.sample_image, format=cfg.INPUT.FORMAT)
        # Do same preprocessing as DefaultPredictor
        aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )
        height, width = original_image.shape[:2]
        image = aug.get_transform(original_image).apply_image(original_image)
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

        inputs = {"image": image, "height": height, "width": width}

        # Sample ready
        sample_inputs = [inputs]
        return sample_inputs


def get_clip_embeddings(vocabulary, prompt='a '):
    from detic.modeling.text.text_encoder import build_text_encoder
    text_encoder = build_text_encoder(pretrain=True)
    text_encoder.eval()
    texts = [prompt + x for x in vocabulary]
    emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu()
    return emb


def reset_cls_test(model, cls_path, num_classes):
    import numpy as np
    from torch.nn import functional as F

    model.roi_heads.num_classes = num_classes
    if type(cls_path) == str:
        print('Resetting zs_weight', cls_path)
        zs_weight = torch.tensor(
            np.load(cls_path),
            dtype=torch.float32).permute(1, 0).contiguous() # D x C
    else:
        zs_weight = cls_path
    zs_weight = torch.cat(
        [zs_weight, zs_weight.new_zeros((zs_weight.shape[0], 1))],
        dim=1) # D x (C + 1)
    if model.roi_heads.box_predictor[0].cls_score.norm_weight:
        zs_weight = F.normalize(zs_weight, p=2, dim=0)
    zs_weight = zs_weight.to(model.device)
    for k in range(len(model.roi_heads.box_predictor)):
        del model.roi_heads.box_predictor[k].cls_score.zs_weight
        model.roi_heads.box_predictor[k].cls_score.zs_weight = zs_weight


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export a model for deployment.")
    parser.add_argument(
        "--format",
        choices=["caffe2", "onnx", "torchscript"],
        help="output format",
        default="torchscript",
    )
    parser.add_argument(
        "--export-method",
        choices=["caffe2_tracing", "tracing", "scripting"],
        help="Method to export models",
        default="tracing",
    )
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--sample-image", default=None, type=str, help="sample image for input")
    parser.add_argument("--run-eval", action="store_true")
    parser.add_argument("--output", help="output directory for the converted model")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    logger = setup_logger()
    logger.info("Command line arguments: " + str(args))
    PathManager.mkdirs(args.output)
    # Disable re-specialization on new shapes. Otherwise --run-eval will be slow
    torch._C._jit_set_bailout_depth(1)

    cfg = setup_cfg(args)

    # create a torch model with custom_classes
    custom_classes = ['scoop', 'teaspoon', 'spoon', 'tea_spoon', 'flatware', 'tong', 'coffee_spoon', 'soupspoon', 'soup_spoon', 'spatula', 'ladle', 'skimmer', 'bowl', 'egg_bowl', 'sugar_bowl', 'washing_bowl', 'salad_bowl', 'cereal_bowl', 'soup_bowl', 'saucepan', 'frying_pan', 'pan', 'cake_pan', 'sauce_pan', 'content_pan', 'wok', 'saucer', 'plate', 'chinaware', 'glass', 'wine_glass', 'chalice', 'dixie_cup', 'flute_glass', 'shot_glass', 'wineglass', 'milk_bottle', 'bottle', 'water_bottle', 'wine_bottle', 'beer_bottle', 'tea_pot', 'pot', 'pressure_pot', 'pasta_pot', 'plastic_pot', 'sauce_pot', 'teapot', 'crock_pot', 'crockpot', 'cup', 'measuring_cup', 'coffee_cup', 'mug', 'teacup', 'tea_cup', 'pitcher', 'coffee_jar', 'sugar_jar', 'honey_jar', 'jar', 'jug', 'coffeepot', 'kettle', 'water_jug', 'urn', 'cream_pitcher', 'coffee_pot', 'container', 'lunch_box', 'sugar_container', 'milk_container', 'rice_container', 'sauce_container', 'food_container', 'casserole', 'knife', 'steak_knife', 'knife_sharpener', 'lime_squeezer', 'peeler', 'grater', 'skimmer', 'cheese_grater', 'masher', 'squeezer', 'potato_peeler', 'lime_juicer', 'scissor', 'tray', 'baking_tray', 'pizza_tray', 'baking_pan', 'serving_board', 'eating_board', 'chopping_board', 'cut_board', 'cutting_board', 'board', 'pasta_strainer', 'strainer', 'mesh_strainer', 'can', 'beer_can', 'milk_can', 'canister', 'wine_bucket', 'bucket', 'plastic_bucket']
    num_classes = len(custom_classes)
    classifier = get_clip_embeddings(custom_classes)

    torch_model = build_model(cfg)
    DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS)
    torch_model.eval()
    print('huihui', torch_model.roi_heads.num_classes)
    reset_cls_test(torch_model, classifier, num_classes)
    # print('huihuii2', torch_model.roi_heads.num_classes)
    # exit()
    # convert and save model
    if args.export_method == "caffe2_tracing":
        sample_inputs = get_sample_inputs(args)
        exported_model = export_caffe2_tracing(cfg, torch_model, sample_inputs)
    elif args.export_method == "scripting":
        exported_model = export_scripting(torch_model)
    elif args.export_method == "tracing":
        sample_inputs = get_sample_inputs(args)
        exported_model = export_tracing(torch_model, sample_inputs)

    # run evaluation with the converted model
    if args.run_eval:
        assert exported_model is not None, (
            "Python inference is not yet implemented for "
            f"export_method={args.export_method}, format={args.format}."
        )
        logger.info("Running evaluation ... this takes a long time if you export to CPU.")
        dataset = cfg.DATASETS.TEST[0]
        data_loader = build_detection_test_loader(cfg, dataset)
        # NOTE: hard-coded evaluator. change to the evaluator for your dataset
        evaluator = COCOEvaluator(dataset, output_dir=args.output)
        metrics = inference_on_dataset(exported_model, data_loader, evaluator)
        print_csv_format(metrics)
    logger.info("Success.")

The line custom_classes = ['scoop', ...] in the script above is where I have added my custom labels.
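If you want to sanity-check the custom classifier before exporting, get_clip_embeddings() should return a D x C matrix (CLIP text-embedding dimension x number of classes), which reset_cls_test() then extends with a background column. A quick check (the 512 is my assumption for the CLIP text encoder Detic ships with):

classifier = get_clip_embeddings(custom_classes)
print(classifier.shape)  # expected: (512, len(custom_classes)), i.e. D x C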

Step 4) Now execute this script (my modified copy of export_model.py, saved as export_model_lvis_vocabulary.py) from Detic's root folder with the command python3 detectron2/tools/deploy/export_model_lvis_vocabulary.py --config-file configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml --sample-image desk.jpg --output ./output --export-method tracing --format onnx MODEL.WEIGHTS models/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth MODEL.DEVICE cuda. This will save the ONNX model at output/model.onnx.
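Optionally, you can verify the exported file before wiring up inference. A small sketch using the onnx package (assuming it is installed; the path matches the --output directory above):

import onnx

model = onnx.load('output/model.onnx')
onnx.checker.check_model(model)  # structural validation
print('opset:', model.opset_import[0].version)
print('inputs:', [i.name for i in model.graph.input])
print('outputs:', [o.name for o in model.graph.output])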


gigasurgeon commented Nov 16, 2023

And to run inference with the ONNX model, I am using @HtwoOtwo's script from this comment -> #107 (comment)

The slightly modified inference script looks like this ->

import argparse
import cv2
import numpy as np
import onnxruntime as ort
import time


class Detic():
    def __init__(self, modelpath, detection_width=800, confThreshold=0.8):
        providers = ['CUDAExecutionProvider']
        self.session = ort.InferenceSession(modelpath, providers=providers)
        model_inputs = self.session.get_inputs()
        self.input_name = model_inputs[0].name
        self.max_size = detection_width
        self.confThreshold = confThreshold
        self.class_names = ['scoop', 'teaspoon', 'spoon', 'tea_spoon', 'flatware', 'tong', 'coffee_spoon', 'soupspoon', 'soup_spoon', 'spatula',
                            'ladle', 'skimmer', 'bowl', 'egg_bowl', 'sugar_bowl', 'washing_bowl', 'salad_bowl', 'cereal_bowl', 'soup_bowl', 'saucepan',
                            'frying_pan', 'pan', 'cake_pan', 'sauce_pan', 'content_pan', 'wok', 'saucer', 'plate', 'chinaware', 'glass', 'wine_glass',
                            'chalice', 'dixie_cup', 'flute_glass', 'shot_glass', 'wineglass', 'milk_bottle', 'bottle', 'water_bottle', 'wine_bottle',
                            'beer_bottle', 'tea_pot', 'pot', 'pressure_pot', 'pasta_pot', 'plastic_pot', 'sauce_pot', 'teapot', 'crock_pot', 'crockpot',
                            'cup', 'measuring_cup', 'coffee_cup', 'mug', 'teacup', 'tea_cup', 'pitcher', 'coffee_jar', 'sugar_jar', 'honey_jar', 'jar',
                            'jug', 'coffeepot', 'kettle', 'water_jug', 'urn', 'cream_pitcher', 'coffee_pot', 'container', 'lunch_box', 'sugar_container',
                            'milk_container', 'rice_container', 'sauce_container', 'food_container', 'casserole', 'knife', 'steak_knife', 'knife_sharpener',
                            'lime_squeezer', 'peeler', 'grater', 'skimmer', 'cheese_grater', 'masher', 'squeezer', 'potato_peeler', 'lime_juicer', 'scissor',
                            'tray', 'baking_tray', 'pizza_tray', 'baking_pan', 'serving_board', 'eating_board', 'chopping_board', 'cut_board', 'cutting_board',
                            'board', 'pasta_strainer', 'strainer', 'mesh_strainer', 'can', 'beer_can', 'milk_can', 'canister', 'wine_bucket', 'bucket', 'plastic_bucket']

        # self.assigned_colors = np.random.randint(0,high=256, size=(len(self.class_names), 3)).tolist()
        self.assigned_colors = np.random.randint(0,high=256, size=(4, 3)).tolist()

    def preprocess(self, srcimg):
        im_h, im_w, _ = srcimg.shape
        dstimg = cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB)
        if im_h < im_w:
            scale = self.max_size / im_h
            oh, ow = self.max_size, scale * im_w
        else:
            scale = self.max_size / im_w
            oh, ow = scale * im_h, self.max_size

        max_hw = max(oh, ow)
        if max_hw > self.max_size:
            scale = self.max_size / max_hw
            oh *= scale
            ow *= scale
        ow = int(ow + 0.5)
        oh = int(oh + 0.5)
        # NOTE: the resize below is hard-coded to (1067, 800), which matches the input
        # size I traced the export with; using the computed (ow, oh) instead would make
        # this preprocessing generic.
        dstimg = cv2.resize(dstimg, (1067, 800))
        return dstimg

    def suppress_overlapping_bboxes(self, pred_boxes, scores, pred_classes, pred_masks):
        pred_boxes = pred_boxes.astype(np.int64)

        coord_str_dict = {}

        for i in range(pred_boxes.shape[0]):
            coord_str = f'{pred_boxes[i][0]}_{pred_boxes[i][1]}_{pred_boxes[i][2]}_{pred_boxes[i][3]}'

            if coord_str not in coord_str_dict:
                coord_str_dict[coord_str] = i
            else:
                if scores[i] > scores[coord_str_dict[coord_str]]:
                    coord_str_dict[coord_str] = i

        pred_boxes = np.array([pred_boxes[coord_str_dict[coord_str]] for coord_str in coord_str_dict])
        scores = np.array([scores[coord_str_dict[coord_str]] for coord_str in coord_str_dict])
        pred_classes = np.array([pred_classes[coord_str_dict[coord_str]] for coord_str in coord_str_dict])
        pred_masks = np.array([pred_masks[coord_str_dict[coord_str]] for coord_str in coord_str_dict])

        return pred_boxes, scores, pred_classes, pred_masks


    def post_processing(self, pred_boxes, scores, pred_classes, pred_masks, im_hw, pred_hw):
        scale_x, scale_y = (im_hw[1] / pred_hw[1], im_hw[0] / pred_hw[0])

        pred_boxes[:, 0::2] *= scale_x
        pred_boxes[:, 1::2] *= scale_y
        pred_boxes[:, [0, 2]] = np.clip(pred_boxes[:, [0, 2]], 0, im_hw[1])
        pred_boxes[:, [1, 3]] = np.clip(pred_boxes[:, [1, 3]], 0, im_hw[0])

        threshold = 0
        widths = pred_boxes[:, 2] - pred_boxes[:, 0]
        heights = pred_boxes[:, 3] - pred_boxes[:, 1]
        keep = (widths > threshold) & (heights > threshold)

        pred_boxes = pred_boxes[keep]
        scores = scores[keep]
        pred_classes = pred_classes[keep]
        pred_masks = pred_masks[keep]

        # mask_threshold = 0.5
        # pred_masks = paste_masks_in_image(
        #     pred_masks[:, 0, :, :], pred_boxes,
        #     (im_hw[0], im_hw[1]), mask_threshold
        # )
        threshold = 0.5
        idx = scores>threshold
        scores = scores[idx]
        pred_boxes = pred_boxes[idx]
        pred_classes = pred_classes[idx]
        pred_masks = pred_masks[idx]

        pred_boxes, scores, pred_classes, pred_masks = self.suppress_overlapping_bboxes(pred_boxes, scores, pred_classes, pred_masks)

        pred = {
            'pred_boxes': pred_boxes,
            'scores': scores,
            'pred_classes': pred_classes,
            'pred_masks': pred_masks,
        }

        # print(pred)
        # exit()
        return pred

    def draw_predictions(self, img, predictions):
        height, width = img.shape[:2]
        default_font_size = int(max(np.sqrt(height * width) // 90, 10))
        boxes = predictions["pred_boxes"].astype(np.int64)
        scores = predictions["scores"]
        # print(predictions["pred_classes"])
        # exit()
        classes_id = predictions["pred_classes"].tolist()
        # masks = predictions["pred_masks"].astype(np.uint8)
        num_instances = len(boxes)
        print('detect', num_instances, 'instances')

        for i in range(num_instances):
            x0, y0, x1, y1 = boxes[i]
            # color = self.assigned_colors[classes_id[i]]
            color = [0,255,0]
            cv2.rectangle(img, (x0, y0), (x1, y1), color=color,thickness=default_font_size // 4)
            # text = "{} {:.0f}%".format(self.class_names[classes_id[i]], round(scores[i],2) * 100)
            text = f"{x0}_{y0}_{x1}_{y1} {round(scores[i],2)} {self.class_names[classes_id[i]]}"
            print(text)
            cv2.putText(img, text, (x0, y0 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, thickness=1, lineType=cv2.LINE_AA)
        return img

    def detect(self, srcimg):
        im_h, im_w = srcimg.shape[:2]
        dstimg = self.preprocess(srcimg)
        pred_hw = dstimg.shape[:2]
        input_image = dstimg.transpose(2, 0, 1).astype(np.float32)
        # input_image = np.expand_dims(dstimg.transpose(2, 0, 1), axis=0).astype(np.float32)

        # Inference
        pred_boxes, pred_classes, pred_masks, scores, _ = self.session.run(None, {self.input_name: input_image})
        # print(len(scores))
        # exit()
        preds = self.post_processing(pred_boxes, scores, pred_classes, pred_masks, (im_h, im_w), pred_hw)
        return preds


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--imgpath", default="desk.jpg", type=str, help="image path")
    parser.add_argument("--confThreshold", default=0.5, type=float, help='class confidence')
    parser.add_argument("--modelpath", type=str, default='onnx_models/model_custom_vocabulary.onnx', help="onnxmodel path")
    args = parser.parse_args()

    mynet = Detic(args.modelpath, confThreshold=args.confThreshold)
    srcimg = cv2.imread(args.imgpath)

    fpses = []

    for i in range(1):
        print(i)
        t1 = time.time()
        preds = mynet.detect(srcimg)
        t2 = time.time()
        fps = 1/(t2-t1)
        fpses.append(fps)
    avg_fps = sum(fpses)/len(fpses)
    print(f'avg_fps: {round(avg_fps, 2)}')
    result = mynet.draw_predictions(srcimg, preds)

    cv2.imwrite('result_onnx.jpg', result)
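One caveat: detect() unpacks the session outputs as (pred_boxes, pred_classes, pred_masks, scores, _). That order comes from the TracingAdapter's flattened outputs schema (the export script logs it as "Outputs schema"), so it is worth printing the ONNX output names once to confirm the unpacking matches your export. A small sketch with onnxruntime:

import onnxruntime as ort

sess = ort.InferenceSession('onnx_models/model_custom_vocabulary.onnx',
                            providers=['CPUExecutionProvider'])
for out in sess.get_outputs():
    print(out.name, out.shape)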

@antoniodecinque99

Hello @gigasurgeon, thanks for the tutorial.
Would you be able to upload directly the onnx file you produced with the script?
Thank you so much

@gigasurgeon

> Hello @gigasurgeon, thanks for the tutorial. Would you be able to upload directly the onnx file you produced with the script? Thank you so much

Here's the ONNX file -> https://drive.google.com/file/d/1hYz19lZk4ugLrUGO0HIP9M2RbXs5A4O-/view?usp=sharing
