In [34]:
import torch
from model.model import parsingNet
from torch.onnx import export
import onnx
import onnxruntime
import numpy as np
import scipy
import cv2
from data.constant import culane_row_anchor

In [6]:
# Build the lane network without the auxiliary segmentation branch (not needed for inference).
net = parsingNet(pretrained=False, backbone='18', cls_dim=(201, 18, 4), use_aux=False)

state_dict = torch.load('checkpoints/culane_18.pth', map_location='cpu')['model']
# Keys may carry a leading 'module.' prefix (presumably from saving a DataParallel-wrapped
# model). Strip only a *leading* prefix: a substring test followed by k[7:] would mangle
# any key that merely contains 'module.' somewhere else.
_prefix = 'module.'
compatible_state_dict = {
    (k[len(_prefix):] if k.startswith(_prefix) else k): v
    for k, v in state_dict.items()
}

# strict=False tolerates extra checkpoint entries the use_aux=False model has no slot for
# (presumably the auxiliary-branch weights). A *missing* key, however, would leave a layer
# at its random initialisation and silently corrupt the exported model, so fail loudly.
load_result = net.load_state_dict(compatible_state_dict, strict=False)
if load_result.missing_keys:
    raise RuntimeError(f'Checkpoint is missing weights for: {load_result.missing_keys}')

# Switch the model to inference mode (affects layers such as dropout / BatchNorm).
net.eval()

# Seeded dummy input at the network resolution; it is reused for the ONNX export and
# for the PyTorch-vs-ONNX Runtime comparison, so make it reproducible.
torch.manual_seed(0)
x = torch.randn(1, 3, 288, 800)
# No autograd graph is needed for the reference output.
with torch.no_grad():
    torch_out = net(x)



In [7]:
# Export `net` to ONNX (opset 11) using `x` as the example input; the graph's single
# input/output are named 'input' / 'output'.
export(
    net,
    x,
    'culane_18.onnx',
    verbose=True,
    input_names=['input'],
    output_names=['output'],
    opset_version=11,
)

In [8]:
# Read the exported file back from disk so the checker validates what was actually written.
onnx_model = onnx.load_model('culane_18.onnx')

In [11]:
# Structurally validate the exported graph before running it.
onnx.checker.check_model(onnx_model)

# Pin the CPU execution provider. Since onnxruntime 1.9, builds that ship several
# providers (e.g. GPU builds) raise unless `providers` is passed explicitly, and CPU keeps
# the numerical comparison against the CPU PyTorch run like-for-like.
ort_session = onnxruntime.InferenceSession("culane_18.onnx", providers=['CPUExecutionProvider'])
def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a torch tensor to a NumPy array on the CPU.

    The tensor is detached first so this also works for tensors that track gradients
    (calling ``.numpy()`` on those directly raises).
    """
    detached = tensor.detach()
    return detached.cpu().numpy()
# Feed ONNX Runtime the exact tensor that produced `torch_out` in PyTorch.
ort_inputs = {
    ort_session.get_inputs()[0].name: to_numpy(x),
}
ort_outs = ort_session.run(output_names=None, input_feed=ort_inputs)

# Only the first output (exported under the name 'output') is compared; both backends
# must agree within the given relative/absolute tolerance.
np.testing.assert_allclose(
    actual=to_numpy(torch_out),
    desired=ort_outs[0],
    rtol=1e-03,
    atol=1e-05,
)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")

Exported model has been tested with ONNXRuntime, and the result looks good!


In [43]:
from PIL import Image
import scipy.special  # a bare `import scipy` does not expose `scipy.special` on older SciPy releases
import torchvision.transforms as transforms

# Single CULane frame to run through the exported model.
# NOTE(review): absolute local path — point this at your own CULane copy.
img_path = 'I:/Ultra-Fast-Lane-Detection/CULane/driver_37_30frame/05181432_0203.MP4/04320.jpg'

# Network input resolution; must match the dummy tensor used for the ONNX export.
input_h, input_w = 288, 800

# Force 3-channel RGB so Normalize always sees 3 channels (grayscale / RGBA / palette
# files would otherwise break it); the context manager closes the file handle.
with Image.open(img_path) as pil_img:
    img = pil_img.convert('RGB')

img_transforms = transforms.Compose([
    transforms.Resize((input_h, input_w)),
    transforms.ToTensor(),
    # ImageNet mean / std normalisation.
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

image = img_transforms(img).unsqueeze(0)  # add a batch dimension -> (1, 3, H, W)
print(image.shape)

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(image)}
ort_outs = ort_session.run(None, ort_inputs)

# First output, batch element 0. The model was built with cls_dim=(201, 18, 4), i.e.
# (grid cells + 1, rows per lane, lanes).
lane_logits = ort_outs[0][0]
print(lane_logits.shape)

row_anchor = culane_row_anchor
griding_num = lane_logits.shape[0] - 1   # the last class index is treated as "no lane" below
cls_num_per_lane = lane_logits.shape[1]

# x coordinate (in network-input pixels) of each grid cell, and the spacing between cells.
col_sample = np.linspace(0, input_w - 1, griding_num)
col_sample_w = col_sample[1] - col_sample[0]

# Reverse the row axis; the row_anchor lookup below compensates with cls_num_per_lane-1-k.
lane_logits = lane_logits[:, ::-1, :]
# Sub-cell location = softmax-weighted mean of 1-based grid indices, so 0 can mean "absent".
prob = scipy.special.softmax(lane_logits[:-1, :, :], axis=0)
idx = (np.arange(griding_num) + 1).reshape(-1, 1, 1)
loc = np.sum(prob * idx, axis=0)
# Wherever the most likely class is the "no lane" class, zero out the location.
out_j = np.argmax(lane_logits, axis=0)
loc[out_j == griding_num] = 0
out_j = loc

vis = cv2.imread(img_path)
if vis is None:
    raise FileNotFoundError(f'cv2 could not read {img_path}')
# Scale to the frame's actual resolution instead of assuming CULane's 1640x590.
img_h, img_w = vis.shape[:2]
for i in range(out_j.shape[1]):  # each lane
    # Draw only lanes that have a location on more than two row anchors.
    if np.sum(out_j[:, i] != 0) > 2:
        for k in range(out_j.shape[0]):  # each row anchor
            if out_j[k, i] > 0:
                ppp = (int(out_j[k, i] * col_sample_w * img_w / input_w) - 1,
                       int(img_h * (row_anchor[cls_num_per_lane - 1 - k] / input_h)) - 1)
                cv2.circle(vis, ppp, 5, (0, 255, 0), -1)
cv2.imwrite('result.jpg', vis)

torch.Size([1, 3, 288, 800])
(201, 18, 4)


True