In [4]:
import torch
import torchvision.models as models
import onnx
import onnxruntime as ort
import numpy as np
from PIL import Image
from torchvision import transforms

In [7]:
# 图像预处理函数
def preprocess(image_path):
    input_image = Image.open(image_path)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    return input_tensor.unsqueeze(0).numpy()

# 执行推理
def infer(image_path):
    input_data = preprocess(image_path)
    outputs = ort_session.run(None, {'input': input_data})
    return outputs


In [2]:
model = models.resnet18()
model.load_state_dict(torch.load('checkpoint\\resnet18-f37072fd.pth'))

<All keys matched successfully>

In [3]:
# 创建一个dummy输入用于导出模型
dummy_input = torch.randn(1, 3, 224, 224)
# 导出模型为ONNX格式
torch.onnx.export(model, dummy_input, "output\\resnet18.onnx", 
                  opset_version=11, 
                  input_names=['input'], 
                  output_names=['output'])

In [6]:
# 加载ONNX模型
onnx_model = onnx.load("output\\resnet18.onnx")
onnx.checker.check_model(onnx_model)

# 创建ONNX Runtime推理会话
ort_session = ort.InferenceSession("output\\resnet18.onnx")

In [9]:
# 示例推理
result = infer("data\\00000.jpg")
print("Inference result:", result)

Inference result: [array([[ 1.11059713e+00,  5.96479416e+00,  5.14497042e+00,
         2.13810420e+00,  5.48775673e+00,  1.65269542e+00,
         4.02171707e+00,  5.22111130e+00,  5.81609917e+00,
         1.19154348e+01, -1.87363851e+00,  2.48587918e+00,
         4.55490017e+00,  1.21734226e+00, -4.26583827e-01,
        -1.33086133e+00,  1.99815929e-01,  5.35940886e-01,
        -1.44945312e+00, -5.32699645e-01,  3.32774568e+00,
         6.82639599e-01,  3.66255689e+00,  6.26690960e+00,
        -2.01597166e+00, -3.75951672e+00,  8.17122519e-01,
         1.53925061e+00, -1.05306184e+00,  3.48280525e+00,
        -5.08167744e-01, -1.36518979e+00,  4.54099536e-01,
         3.76267314e-01,  2.10085106e+00,  1.18279076e+00,
         3.56633997e+00, -4.34761047e-01,  2.23057675e+00,
         7.87392437e-01,  1.14518702e+00,  1.52058637e+00,
         1.64700544e+00,  1.86708999e+00,  1.76150537e+00,
         2.56161594e+00,  9.70054626e-01, -1.38508654e+00,
         3.41672635e+00,  4.82998037e