In [1]:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import cv2
import torch
import numpy as np
import onnxruntime as ort
import matplotlib.pyplot as plt

# Device on which the model, its weights and the dummy export input live.
device = 'cpu'
# Class-id -> display name for the detector's foreground classes.
classes = {1: 'fire', 2: 'smoke'}

# Build the Faster R-CNN architecture without pretrained weights; the
# fine-tuned weights are loaded from WEIGHTS_FILE below.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)

WEIGHTS_FILE = r"D:\nn\code\ckpt\faster_rcnn_state_2.pth"

# 2 object classes (see `classes`) + 1 background class, following the
# torchvision detection convention that index 0 is background.
num_classes = 3

# Number of input features of the existing box classifier.
in_features = model.roi_heads.box_predictor.cls_score.in_features

# Replace the default head with one sized for our classes so the shapes
# match the fine-tuned state_dict.
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Load the trained weights directly onto the target device.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted
# checkpoints (torch>=1.13 also supports weights_only=True).
model.load_state_dict(torch.load(WEIGHTS_FILE, map_location=device))

model.to(device)

# Inference mode before tracing/exporting.
model.eval()

# Dummy (N, C, H, W) input used only to trace the graph during ONNX export;
# it must live on the same device as the model.
image = torch.rand([1, 3, 416, 416], device=device)




## 固定输入

In [16]:
# ONNX operator reference: https://github.com/onnx/onnx/blob/main/docs/Operators.md
# Fixed input size: no dynamic_axes are given, so the exported graph keeps the
# exact shape of the dummy `image`.
torch.onnx.export(
    model,
    args=image,
    f='fasterrcnn.onnx',
    opset_version=11,
)

## 动态输入

In [None]:
# Dynamic input size: the exported graph accepts images whose height/width
# differ from the dummy `image` used for tracing.
input_name = 'input'
# The detector's result is exported as three separate outputs, in this order.
output_names = ['boxes', 'labels', 'scores']

# NCHW layout: axis 2 is height, axis 3 is width.
# Axis 0 of every output is the (data-dependent) number of detections.
# NOTE(review): the batch axis is intentionally left static - tracing a 4-D
# input through torchvision's detector presumably unrolls the per-image loop
# for a batch of exactly 1; verify before declaring it dynamic.
dynamic_axes = {
    input_name: {2: 'in_height', 3: 'in_width'},
    **{name: {0: 'num_detections'} for name in output_names},
}

torch.onnx.export(model,
                  image,
                  "fasterrcnn.onnx",
                  opset_version=11,
                  input_names=[input_name],
                  output_names=output_names,
                  dynamic_axes=dynamic_axes)


## onnx推理

In [6]:
class Onnx_Module(ort.InferenceSession):
    '''Callable ONNX Runtime session: ``session(x1, x2, ...)`` runs the model.

    Positional arrays are bound to the model inputs in declaration order and
    all model outputs are returned as a list, in declaration order.

    Provider selection: CUDA is preferred when the installed onnxruntime
    offers it, with CPU kept as a fallback.
    '''
    # Execution providers in priority order, selected by name. Picking a single
    # provider by list index depends on the order/contents of
    # get_available_providers(), which differ between onnxruntime builds.
    preferred_providers = [
        name for name in ('CUDAExecutionProvider', 'CPUExecutionProvider')
        if name in ort.get_available_providers()
    ] or ort.get_available_providers()
    # Primary provider; kept so code reading the old attribute still works.
    provider = preferred_providers[0]

    def __init__(self, file, providers=None):
        '''Create a session for the model at ``file``.

        file: path to (or bytes of) the ONNX model.
        providers: optional explicit provider list; defaults to
            ``preferred_providers``.
        '''
        super().__init__(file, providers=providers or self.preferred_providers)
        # Model input/output names (see ort.NodeArg).
        self.inputs = [node_arg.name for node_arg in self.get_inputs()]
        self.outputs = [node_arg.name for node_arg in self.get_outputs()]

    def __call__(self, *arrays):
        '''Run inference on ``arrays`` and return the list of output arrays.

        Raises:
            ValueError: if the number of arrays differs from the model's inputs
                (zip would otherwise silently drop or ignore inputs).
        '''
        if len(arrays) != len(self.inputs):
            raise ValueError(
                f'expected {len(self.inputs)} input array(s) {self.inputs}, '
                f'got {len(arrays)}')
        input_feed = dict(zip(self.inputs, arrays))
        return self.run(self.outputs, input_feed)

In [7]:
# Use a separate name so the PyTorch `model` from the first cell is not
# clobbered (the export cells keep working if re-run afterwards).
onnx_model = Onnx_Module('fasterrcnn.onnx')

imgpath = r'D:\nn\code\datasets\fire\test\20230322101540.jpg'


def load_image_nchw(path):
    '''Read an image file as a float32 (1, 3, H, W) RGB array scaled to [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    '''
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread returns None instead of raising on a missing/unreadable file.
    if bgr is None:
        raise FileNotFoundError(f'could not read image: {path}')
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # HWC -> CHW, add a batch axis, and make the buffer C-contiguous.
    return np.ascontiguousarray(rgb.transpose(2, 0, 1)[np.newaxis])


img = load_image_nchw(imgpath)
# res = [boxes, labels, scores], in the order of the session's outputs.
res = onnx_model(img)
res

[array([[306.73575, 361.15707, 544.3265 , 477.9342 ],
       [290.35562, 260.72772, 324.5933 , 456.56262],
       [862.0033 , 118.4404 , 881.1725 , 141.63399]], dtype=float32), array([1, 1, 1], dtype=int64), array([0.9768057, 0.9387695, 0.7504843], dtype=float32)]


In [None]:
def draw_detections(image, boxes, labels, scores, class_names,
                    box_color=(0, 220, 0), text_color=(220, 0, 0)):
    '''Draw detection boxes and "name score" captions onto ``image`` in place.

    image: HxWx3 uint8 array; colours are interpreted in its channel order.
    boxes: (N, 4) x1, y1, x2, y2 pixel coordinates.
    labels: (N,) class ids, looked up in ``class_names`` (unknown ids are
        shown as the raw number).
    scores: (N,) confidence scores.
    Returns ``image`` for convenience.
    '''
    for box, label, score in zip(boxes, labels, scores):
        # OpenCV wants plain Python ints for points; int() truncates like astype.
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(image, (x1, y1), (x2, y2), color=box_color, thickness=2)
        name = class_names.get(int(label), str(int(label)))
        # Keep the caption inside the image when a box touches the top edge.
        text_y = max(y1 - 5, 15)
        cv2.putText(image, f'{name} {score:.2f}', (x1, text_y),
                    fontFace=cv2.FONT_HERSHEY_COMPLEX,
                    fontScale=0.7,
                    color=text_color,
                    thickness=1,
                    lineType=cv2.LINE_AA)
    return image


classes = {1: 'fire', 2: 'smoke'}

bgr_sample = cv2.imread(imgpath, cv2.IMREAD_COLOR)
if bgr_sample is None:
    raise FileNotFoundError(f'could not read image: {imgpath}')
# Stay in uint8 so both OpenCV drawing and matplotlib handle it natively.
sample = cv2.cvtColor(bgr_sample, cv2.COLOR_BGR2RGB)

boxes, labels, scores = res
draw_detections(sample, boxes, labels, scores, classes)

fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(sample)
ax.set_title('ONNX Faster R-CNN detections')
ax.axis('off');