-
Notifications
You must be signed in to change notification settings - Fork 1
/
vgg.py
34 lines (22 loc) · 890 Bytes
/
vgg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# Export torchvision's ImageNet-pretrained VGG-11 to ONNX and sanity-check it:
#   1. trace the model on an all-zeros 1x3x224x224 input and write
#      ./weights/vgg11.onnx (opset 11);
#   2. reload the file with `onnx` and run the structural checker;
#   3. run the exported graph with ONNX Runtime on the same input, print the
#      result, and compare it numerically against the PyTorch model's output.
import os

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torchvision

ONNX_PATH = "./weights/vgg11.onnx"

img_size = (224, 224)
dummy_input = torch.zeros((1, 3) + img_size)

# pretrained=True downloads ImageNet weights. NOTE(review): newer torchvision
# releases deprecate this flag in favour of `weights=`; kept for compatibility
# with older installs.
model = torchvision.models.vgg11(pretrained=True).cpu()
# VGG's classifier contains Dropout layers; eval mode makes the PyTorch
# reference output below deterministic and matches what gets exported.
model.eval()

# Only the real graph input is named. VGG-11 has 22 parameter tensors, so a
# fixed list of extra "learned_N" labels would name only some of them.
# Parameters keep the exporter's default names.
input_names = ["actual_input_1"]
output_names = ["output1"]
print(input_names)
print(output_names)

# torch.onnx.export does not create missing parent directories.
os.makedirs(os.path.dirname(ONNX_PATH), exist_ok=True)
torch.onnx.export(
    model,
    dummy_input,
    ONNX_PATH,
    verbose=True,
    input_names=input_names,
    output_names=output_names,
    opset_version=11,
)
print('---1---')

# Reload under a distinct name so `model` keeps referring to the torch module.
onnx_model = onnx.load(ONNX_PATH)
onnx.checker.check_model(onnx_model)
# printable_graph returns a string; it must be printed to be seen.
print(onnx.helper.printable_graph(onnx_model.graph))
print('---2---')

# An explicit provider list is required by ORT >= 1.9 builds that ship more
# than one execution provider (e.g. onnxruntime-gpu).
ort_session = ort.InferenceSession(ONNX_PATH, providers=["CPUExecutionProvider"])
# Feed exactly the tensor used for tracing (float32, shape (1, 3, *img_size)).
outputs = ort_session.run(None, {input_names[0]: dummy_input.numpy()})
print(outputs[0])

# Numerically compare ONNX Runtime against PyTorch on the same input.
with torch.no_grad():
    torch_out = model(dummy_input).numpy()
np.testing.assert_allclose(torch_out, outputs[0], rtol=1e-03, atol=1e-05)
print("ONNX Runtime output matches PyTorch within tolerance")