forked from SHI-Labs/OneFormer
Showing 10 changed files with 406 additions and 18 deletions.
@@ -0,0 +1,68 @@
# Run the exported OneFormer ONNX model on a sample image and save a
# colorized segmentation overlay to ./output/predictions.png.
import sys

import numpy as np
import cv2
import onnxruntime

model_dir = "./output"
model = model_dir + "/model.onnx"
path = "../datasets/test_images/Bathroom1.jpg"  # sys.argv[1]

ade20k_info_path = "../datasets/ade20k_label_colors.txt"


def read_ade20k_info(info_path=ade20k_info_path):
    # Each line of the label file holds a label name followed by its R, G, B values.
    with open(info_path) as fp:
        lines = fp.readlines()

    labels = [line.strip().replace(';', ',').split(',')[0] for line in lines]
    colors = np.array([line.strip().replace(';', ',').split(',')[-3:] for line in lines]).astype(int)

    return colors, labels


colors, labels = read_ade20k_info()


def util_draw_seg(seg_map, image, alpha=0.5):
    # Convert segmentation prediction to colors
    color_segmap = cv2.resize(image, (seg_map.shape[1], seg_map.shape[0]))
    color_segmap[seg_map > 0] = colors[seg_map[seg_map > 0]]

    # Resize to match the image shape
    color_segmap = cv2.resize(color_segmap, (image.shape[1], image.shape[0]))

    # Fuse both images (alpha == 0 shows them side by side instead of blending)
    if alpha == 0:
        combined_img = np.hstack((image, color_segmap))
    else:
        combined_img = cv2.addWeighted(image, alpha, color_segmap, 1 - alpha, 0)

    cv2.imwrite("./output/predictions.png", combined_img)


# Preprocess the image: resize to the 512x512 network input, then move channels first (HWC -> CHW)
input_img = cv2.imread(path)
img = cv2.resize(input_img, dsize=(512, 512), interpolation=cv2.INTER_AREA)
data = img.astype("float32").transpose(2, 0, 1)

# The model was traced with a single CHW image (see export_tracing below),
# so the tensor is fed without a batch dimension.
session = onnxruntime.InferenceSession(model, None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
print(input_name)
print(output_name)

result = session.run([output_name], {input_name: data})
# Per-pixel class id: argmax over the class-score channel
prediction = np.argmax(np.array(result[0]).squeeze(), axis=0)
print(prediction.shape)
print(prediction)

util_draw_seg(prediction, input_img, 0)
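
Before running this script, it can help to confirm the exported graph is structurally valid and to inspect its declared input and output shapes. A minimal sketch, assuming the export step below wrote ./output/model.onnx:

import onnx

m = onnx.load("./output/model.onnx")
onnx.checker.check_model(m)  # raises if the graph is malformed

def dims(value_info):
    # dim_value is 0 for symbolic dimensions; fall back to the dim_param name
    return [d.dim_value or d.dim_param for d in value_info.type.tensor_type.shape.dim]

for v in m.graph.input:
    print("input ", v.name, dims(v))
for v in m.graph.output:
    print("output", v.name, dims(v))

The printed names should match the input_names=["input"] and output_names=["output"] passed to torch.onnx.export in the exporter below.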
@@ -0,0 +1,257 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import os
import sys
from typing import Dict, List, Tuple

import torch
from torch import Tensor, nn

import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.data import build_detection_test_loader, detection_utils
from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format
from detectron2.export import (
    # STABLE_ONNX_OPSET_VERSION,  # an explicit opset is passed to torch.onnx.export instead
    TracingAdapter,
    dump_torchscript_IR,
    scripting_with_instances,
)
from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.projects.point_rend import add_pointrend_config
from detectron2.structures import Boxes
from detectron2.utils.env import TORCH_VERSION
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger

# make the oneformer package importable when running from this directory
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from oneformer import (
    add_oneformer_config,
    add_common_config,
    add_swin_config,
    add_dinat_config,
    add_convnext_config,
)

def setup_cfg(args):
    cfg = get_cfg()
    # cuda context is initialized before creating dataloader, so we don't fork anymore
    cfg.DATALOADER.NUM_WORKERS = 0
    add_deeplab_config(cfg)
    add_common_config(cfg)
    add_swin_config(cfg)
    add_dinat_config(cfg)
    add_convnext_config(cfg)
    add_oneformer_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg

def export_caffe2_tracing(cfg, torch_model, inputs):
    from detectron2.export import Caffe2Tracer

    tracer = Caffe2Tracer(cfg, torch_model, inputs)
    if args.format == "caffe2":
        caffe2_model = tracer.export_caffe2()
        caffe2_model.save_protobuf(args.output)
        # draw the caffe2 graph
        caffe2_model.save_graph(os.path.join(args.output, "model.svg"), inputs=inputs)
        return caffe2_model
    elif args.format == "onnx":
        import onnx

        onnx_model = tracer.export_onnx()
        onnx.save(onnx_model, os.path.join(args.output, "model.onnx"))
    elif args.format == "torchscript":
        ts_model = tracer.export_torchscript()
        with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
            torch.jit.save(ts_model, f)
        dump_torchscript_IR(ts_model, args.output)


# experimental. API not yet final
def export_scripting(torch_model):
    assert TORCH_VERSION >= (1, 8)
    fields = {
        "proposal_boxes": Boxes,
        "objectness_logits": Tensor,
        "pred_boxes": Boxes,
        "scores": Tensor,
        "pred_classes": Tensor,
        "pred_masks": Tensor,
        "pred_keypoints": torch.Tensor,
        "pred_keypoint_heatmaps": torch.Tensor,
    }
    assert args.format == "torchscript", "Scripting only supports torchscript format."

    class ScriptableAdapterBase(nn.Module):
        # Use this adapter to workaround https://github.com/pytorch/pytorch/issues/46944
        # by not returning instances but dicts. Otherwise the exported model is not deployable
        def __init__(self):
            super().__init__()
            self.model = torch_model
            self.eval()

    if isinstance(torch_model, GeneralizedRCNN):

        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                instances = self.model.inference(inputs, do_postprocess=False)
                return [i.get_fields() for i in instances]

    else:

        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                instances = self.model(inputs)
                return [i.get_fields() for i in instances]

    ts_model = scripting_with_instances(ScriptableAdapter(), fields)
    with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
        torch.jit.save(ts_model, f)
    dump_torchscript_IR(ts_model, args.output)
    # TODO inference in Python now missing postprocessing glue code
    return None


# experimental. API not yet final
def export_tracing(torch_model, inputs):
    assert TORCH_VERSION >= (1, 8)
    image = inputs[0]["image"]
    inputs = [{"image": image}]  # remove other unused keys

    if isinstance(torch_model, GeneralizedRCNN):

        def inference(model, inputs):
            # use do_postprocess=False so it returns ROI mask
            inst = model.inference(inputs, do_postprocess=False)[0]
            return [{"instances": inst}]

    else:
        inference = None  # assume that we just call the model directly

    traceable_model = TracingAdapter(torch_model, inputs, inference)

    if args.format == "torchscript":
        ts_model = torch.jit.trace(traceable_model, (image,))
        with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
            torch.jit.save(ts_model, f)
        dump_torchscript_IR(ts_model, args.output)
    elif args.format == "onnx":
        with PathManager.open(os.path.join(args.output, "model.onnx"), "wb") as f:
            torch.onnx.export(
                traceable_model,
                (image,),
                f,
                opset_version=16,  # explicit opset instead of STABLE_ONNX_OPSET_VERSION
                do_constant_folding=False,
                input_names=["input"],
                output_names=["output"],
            )
    logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
    logger.info("Outputs schema: " + str(traceable_model.outputs_schema))

    if args.format != "torchscript":
        return None
    if not isinstance(torch_model, (GeneralizedRCNN, RetinaNet)):
        return None

    def eval_wrapper(inputs):
        """
        The exported model does not contain the final resize step, which is typically
        unused in deployment but needed for evaluation. We add it manually here.
        """
        input = inputs[0]
        instances = traceable_model.outputs_schema(ts_model(input["image"]))[0]["instances"]
        postprocessed = detector_postprocess(instances, input["height"], input["width"])
        return [{"instances": postprocessed}]

    return eval_wrapper


def get_sample_inputs(args):
    if args.sample_image is None:
        # get a first batch from dataset
        data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        first_batch = next(iter(data_loader))
        return first_batch
    else:
        # get a sample data
        original_image = detection_utils.read_image(args.sample_image, format=cfg.INPUT.FORMAT)
        # Do same preprocessing as DefaultPredictor
        aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )
        height, width = original_image.shape[:2]
        image = aug.get_transform(original_image).apply_image(original_image)
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

        inputs = {"image": image, "height": height, "width": width}

        # Sample ready
        sample_inputs = [inputs]
        return sample_inputs


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export a model for deployment.")
    parser.add_argument(
        "--format",
        choices=["caffe2", "onnx", "torchscript"],
        help="output format",
        default="torchscript",
    )
    parser.add_argument(
        "--export-method",
        choices=["caffe2_tracing", "tracing", "scripting"],
        help="Method to export models",
        default="tracing",
    )
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--sample-image", default=None, type=str, help="sample image for input")
    parser.add_argument("--run-eval", action="store_true")
    parser.add_argument("--output", help="output directory for the converted model")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    logger = setup_logger()
    logger.info("Command line arguments: " + str(args))
    PathManager.mkdirs(args.output)
    # Disable re-specialization on new shapes. Otherwise --run-eval will be slow
    torch._C._jit_set_bailout_depth(1)

    cfg = setup_cfg(args)

    # create a torch model
    torch_model = build_model(cfg)
    DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS)
    torch_model.eval()

    # convert and save model
    if args.export_method == "caffe2_tracing":
        sample_inputs = get_sample_inputs(args)
        exported_model = export_caffe2_tracing(cfg, torch_model, sample_inputs)
    elif args.export_method == "scripting":
        exported_model = export_scripting(torch_model)
    elif args.export_method == "tracing":
        sample_inputs = get_sample_inputs(args)
        exported_model = export_tracing(torch_model, sample_inputs)

    # run evaluation with the converted model
    if args.run_eval:
        assert exported_model is not None, (
            "Python inference is not yet implemented for "
            f"export_method={args.export_method}, format={args.format}."
        )
        logger.info("Running evaluation ... this takes a long time if you export to CPU.")
        dataset = cfg.DATASETS.TEST[0]
        data_loader = build_detection_test_loader(cfg, dataset)
        # NOTE: hard-coded evaluator. change to the evaluator for your dataset
        evaluator = COCOEvaluator(dataset, output_dir=args.output)
        metrics = inference_on_dataset(exported_model, data_loader, evaluator)
        print_csv_format(metrics)
    logger.info("Success.")