## 模型推理

In [2]:
import onnxruntime as ort


class Onnx_Module(ort.InferenceSession):
    """ONNX Runtime inference session that can be called like a function.

    The names of the graph inputs and outputs are read once at construction,
    so a call only needs the input arrays, in graph order.

    Attributes:
        provider: Execution provider used for every session: the CUDA
            provider when this onnxruntime build offers it, otherwise CPU.
        inputs: Names of the graph inputs, in graph order.
        outputs: Names of the graph outputs, in graph order.
    """

    # Select the provider by name rather than by list position: the order and
    # length of the available-provider list vary between onnxruntime builds.
    provider = (
        "CUDAExecutionProvider"
        if "CUDAExecutionProvider" in ort.get_available_providers()
        else "CPUExecutionProvider"
    )

    def __init__(self, file):
        """Load the model at ``file`` on the selected execution provider."""
        super().__init__(file, providers=[self.provider])
        # Node metadata is described by ort.NodeArg
        self.inputs = [node.name for node in self.get_inputs()]
        self.outputs = [node.name for node in self.get_outputs()]

    def __call__(self, *arrays):
        """Run inference on the given input arrays.

        Args:
            *arrays: Input arrays, paired positionally with ``self.inputs``.

        Returns:
            The result of ``run`` for every name in ``self.outputs``.
        """
        input_feed = dict(zip(self.inputs, arrays))
        return self.run(self.outputs, input_feed)

## FasterRCNN

In [5]:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Device the PyTorch model is moved to before export
device = "cpu"
# Label id -> class name, used to name the detections returned by the model
classes = {1: "fire", 2: "smoke"}

# Build Faster R-CNN (ResNet-50 FPN) with no pretrained weights; the trained
# weights are loaded from WEIGHTS_FILE below
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=None, weights_backbone=None
)

# Replace the box predictor head with one sized for this task's class count
num_classes = 3  # NOTE(review): presumably fire + smoke + background — confirm
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Load the trained state dict onto the CPU, then switch to inference mode
WEIGHTS_FILE = r"D:\nn\code\ckpt\faster_rcnn_state_2.pth"
model.load_state_dict(torch.load(WEIGHTS_FILE, map_location="cpu"))
model.to(device)
model.eval()

# Random example input (one 3-channel 416x416 image) used for the ONNX export
image = torch.rand([1, 3, 416, 416])

## 固定输入

In [None]:
# ONNX operator reference: https://github.com/onnx/onnx/blob/main/docs/Operators.md
# Export with a fixed input size: no dynamic axes are declared here
torch.onnx.export(model, image, "fasterrcnn.onnx", opset_version=11)

## 动态输入

In [None]:
input_name = "input"
output_name = "output"
# Reference: https://blog.csdn.net/LimitOut/article/details/107117759
# Export with symbolic dimensions so the graph is not tied to the shape of
# the example `image`.
torch.onnx.export(
    model,
    image,
    "fasterrcnn.onnx",
    opset_version=11,
    input_names=[input_name],
    output_names=[output_name],
    dynamic_axes={
        # The input is NCHW, so axis 2 is the height and axis 3 the width.
        input_name: {0: "batch_size", 2: "in_height", 3: "in_width"},
        # The first output is indexed per detection and has no axes 2 or 3,
        # so only its leading axis is declared dynamic.
        output_name: {0: "num_detections"},
    },
)

In [None]:
import cv2
import numpy as np


def read_img(imgpath, size=(640, 640)):
    """Load an image and convert it into a normalized NCHW float batch.

    Args:
        imgpath: Path of the image file to read.
        size: Target size passed to ``cv2.resize``, as ``(width, height)``.

    Returns:
        A tuple ``(img, imdata)``. ``img`` is the resized RGB image and
        ``imdata`` is the same image as float32 divided by 255, with shape
        ``(1, channels, height, width)``.

    Raises:
        OSError: If OpenCV cannot read ``imgpath`` (missing or undecodable).
    """
    img = cv2.imread(imgpath, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising
        raise OSError("cannot read image: {!r}".format(imgpath))
    img = cv2.cvtColor(cv2.resize(img, size), cv2.COLOR_BGR2RGB)

    imdata = img.astype(np.float32) / 255.0
    # Add the batch axis, then move the channels from the last axis to axis 1
    imdata = np.transpose(imdata[np.newaxis, ...], (0, 3, 1, 2))
    return img, imdata

In [None]:
import matplotlib.pyplot as plt
from box import Box

# Run the exported Faster R-CNN through ONNX Runtime on a single image
model = Onnx_Module("fasterrcnn.onnx")

imgpath = r"4.jpg"
img, im_data = read_img(imgpath)

# res[0] supplies the box coordinates, res[1] the matching class ids
res = model(im_data)
boxes = [
    Box(*coords, classes[label])
    for coords, label in zip(res[0], res[1])
]

# Outline and label every detection on the current axes, then show the image
for box in boxes:
    rect = plt.Rectangle(
        (box.xmin, box.ymin),
        box.xmax - box.xmin,
        box.ymax - box.ymin,
        edgecolor="r",
        linewidth=1,
        fill=False,
    )
    plt.text(
        box.xmin,
        box.ymin,
        box.name,
        fontsize=10,
        color="r",
        style="italic",
        weight="light",
    )
    plt.gca().add_patch(rect)
plt.imshow(img)
plt.show()

## yolov5

In [None]:
# Class names, looked up by predicted class index when drawing
CLASSES = [
    "fire",
]  # a single class here, not the 80 COCO classes


def pynms(dets, thresh):
    """Greedy non-maximum suppression.

    Args:
        dets: 2-D array whose columns 0-3 are ``x1, y1, x2, y2`` and whose
            column 4 is the score; any further columns are ignored.
        thresh: A box is dropped when its IoU with an already kept box is
            greater than this value.

    Returns:
        List of row indices into ``dets`` to keep, highest score first.
    """
    x1, y1, x2, y2, scores = (dets[:, col] for col in range(5))
    # Widths and heights are measured as (max - min + 1) throughout
    areas = (y2 - y1 + 1) * (x2 - x1 + 1)
    order = scores.argsort()[::-1]  # indices sorted by descending score

    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(best)

        # Intersection of the best box with every remaining box; clamped to 0
        # when the boxes do not overlap
        inter_w = np.maximum(
            0, np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]) + 1
        )
        inter_h = np.maximum(
            0, np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]) + 1
        )
        inter = inter_w * inter_h
        ious = inter / (areas[best] + areas[rest] - inter)

        # Only boxes that do not overlap the kept box too much stay in play
        order = rest[ious <= thresh]

    return keep


def xywh2xyxy(x):
    """Convert rows from ``[cx, cy, w, h, ...]`` to ``[x1, y1, x2, y2, ...]``.

    Returns a copy of ``x``; columns after the first four are carried over
    unchanged and ``x`` itself is not modified.
    """
    centers = x[:, :2]
    half_sizes = x[:, 2:4] / 2
    converted = np.copy(x)
    converted[:, :2] = centers - half_sizes
    converted[:, 2:4] = centers + half_sizes
    return converted


def filter_box(org_box, conf_thres, iou_thres):
    """Filter raw detections by confidence, then apply per-class NMS.

    Args:
        org_box: Raw model output. After removing size-1 axes each row is
            read as ``[cx, cy, w, h, confidence, class scores...]``.
        conf_thres: Rows whose confidence (column 4) is not greater than
            this value are discarded.
        iou_thres: IoU threshold handed to ``pynms``.

    Returns:
        Array whose rows are ``[x1, y1, x2, y2, confidence, class index]``.
        Classes are processed in ascending index order. When nothing passes
        the confidence threshold the result has shape ``(0, 6)``, so callers
        can still index its columns.
    """
    # Drop size-1 axes but keep a row axis even when one candidate remains
    org_box = np.atleast_2d(np.squeeze(org_box))
    box = org_box[org_box[..., 4] > conf_thres]
    if len(box) == 0:
        return np.empty((0, 6), dtype=org_box.dtype)

    # Best-scoring class for every surviving row
    cls = np.argmax(box[..., 5:], axis=-1)

    output = []
    for curr_cls in np.unique(cls):
        # Boolean indexing copies, so writing the class index is side-effect free
        curr_cls_box = box[cls == curr_cls][:, :6]
        curr_cls_box[:, 5] = curr_cls  # column 5 now stores the class index
        curr_cls_box = xywh2xyxy(curr_cls_box)
        keep = pynms(curr_cls_box, iou_thres)  # row indices surviving NMS
        output.extend(curr_cls_box[k] for k in keep)
    return np.array(output)


def draw(image, box_data):
    """Draw detections onto ``image`` in place and print each one.

    Args:
        image: Image array passed to the OpenCV drawing calls; it is
            modified in place.
        box_data: Array whose rows are read as
            ``[x1, y1, x2, y2, score, class index]``. An empty array draws
            and prints nothing.
    """
    if box_data.size == 0:
        # Nothing detected; an empty 1-D array has no columns to index
        return

    boxes = box_data[..., :4].astype(np.int32)  # integer pixel coordinates
    scores = box_data[..., 4]
    class_ids = box_data[..., 5].astype(np.int32)

    for corners, score, cl in zip(boxes, scores, class_ids):
        # Plain Python ints are accepted as point coordinates by OpenCV
        x1, y1, x2, y2 = (int(v) for v in corners)
        label = CLASSES[cl]
        print("class: {}, score: {}".format(label, score))
        print(
            "box coordinate left,top,right,down: [{}, {}, {}, {}]".format(
                x1, y1, x2, y2
            )
        )

        cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 2)
        cv2.putText(
            image,
            "{0} {1:.2f}".format(label, score),
            (x1, y1),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 0, 255),
            2,
        )


# Run the YOLOv5 ONNX model on one image and display the filtered detections
model = Onnx_Module(r"yolov5.onnx")
imgpath = r"4.jpg"
img, im_data = read_img(imgpath)
output = model(im_data)

# Confidence threshold 0.5, then non-maximum suppression at IoU 0.5
outbox = filter_box(output, 0.5, 0.5)
draw(img, outbox)  # draws onto img
plt.imshow(img)
plt.show()