I'm trying to convert a model from torch==1.6.0 to tensorflow==2.4.1 using onnx-tf==1.7.0.
My code is as follows:
import onnx
import tensorflow as tf
import torch
from onnx_tf.backend import prepare

def torchToTensorflow(self, model):
    input = torch.randn([100, *image_size(DATASET_NAME)])
    # Set the input tensor names (one name per input)
    input_names = ["input"]
    # Set the output tensor names
    output_names = ["output"]
    # Custom ONNX file name and path
    onnx_filename = "model.onnx"
    # Run the export and save the ONNX file
    torch.onnx.export(model, input, onnx_filename, verbose=True, input_names=input_names,
                      output_names=output_names)
    onnx_model = onnx.load("model.onnx")  # load onnx model
    tf_exp = prepare(onnx_model)  # prepare tf representation
    tf_exp.export_graph("model.pb")  # export the model
    print(tf_exp.tensor_dict)
    with tf.Graph().as_default():
        output_graph_def = tf.compat.v1.GraphDef()
        output_graph_path = "model.pb/saved_model.pb"
        with open(output_graph_path, 'rb') as f:
            output_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(output_graph_def, name="")
    return output_graph_def

However, I get the following error:
File "architecture_test.py", line 396, in torchToTensorflow
output_graph_def.ParseFromString(f.read())
google.protobuf.message.DecodeError: Error parsing message
Could you help me solve this problem?
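Some context on the DecodeError: the code already reads `model.pb/saved_model.pb`, so with TensorFlow 2.x `export_graph("model.pb")` evidently produced a directory named `model.pb` rather than a single frozen-graph file. In TensorFlow's SavedModel format, the `saved_model.pb` inside such a directory holds a serialized `SavedModel` message, not a bare `GraphDef`. That mismatch is a likely reason `GraphDef.ParseFromString` fails. The sketch below is a hypothetical, stdlib-only helper (`classify_export` is not part of onnx-tf or TensorFlow). It only inspects the file layout to suggest which loader a path calls for and does not parse any protobuf.

```python
import os


def classify_export(path: str) -> str:
    """Hypothetical helper: guess what kind of TensorFlow artifact is at `path`.

    Assumption: a SavedModel is a directory containing `saved_model.pb`,
    whereas a frozen graph is a standalone .pb file holding a GraphDef.
    Only the layout is checked; file contents are never read.
    """
    if os.path.isdir(path):
        if os.path.isfile(os.path.join(path, "saved_model.pb")):
            # A SavedModel directory: load the whole folder, not the .pb inside it.
            return "savedmodel-dir"
        return "unknown-dir"
    if os.path.isfile(path):
        if os.path.basename(path) == "saved_model.pb":
            # The SavedModel proto itself: parsing it as a GraphDef is the mistake.
            return "savedmodel-proto"
        # Some other single .pb file, possibly a frozen GraphDef (not verified here).
        return "single-file"
    return "missing"
```

Under these assumptions, `classify_export("model.pb")` would return `"savedmodel-dir"` and `classify_export("model.pb/saved_model.pb")` would return `"savedmodel-proto"`. For the directory, `tf.saved_model.load("model.pb")` is the TF 2.x loader. For the file, the bytes could instead be parsed with `saved_model_pb2.SavedModel` from `tensorflow.core.protobuf`, whose `meta_graphs[0].graph_def` field holds a `GraphDef`. Both are suggestions to try, not a confirmed fix for this issue.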