
How to convert a model from PyTorch to tensorflow==2.4.1 #892

@chenyuqi990215

I am trying to convert a model from torch==1.6.0 to tensorflow==2.4.1 using onnx-tf==1.7.0.

My code is as follows:

    def torchToTensorflow(self, model):
        input = torch.randn([100, *image_size(DATASET_NAME)])
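        # Set the input tensor names (one name per input if there are several)
        input_names = ["input"]
        # Set the output tensor names
        output_names = ["output"]
        # Choose the ONNX file name and path
        onnx_filename = "model.onnx"
        # Run the export and save the file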
        torch.onnx.export(model, input, onnx_filename, verbose=True, input_names=input_names,
                          output_names=output_names)
        onnx_model = onnx.load("model.onnx")  # load onnx model
        tf_exp = prepare(onnx_model)  # prepare tf representation
        tf_exp.export_graph("model.pb")  # export the model
        print(tf_exp.tensor_dict)
        with tf.Graph().as_default():
            output_graph_def = tf.compat.v1.GraphDef()
            output_graph_path = "model.pb/saved_model.pb"

            with open(output_graph_path, 'rb') as f:
                output_graph_def.ParseFromString(f.read())
                _ = tf.import_graph_def(output_graph_def, name="")
        return output_graph_def

However, I get the following error:

  File "architecture_test.py", line 396, in torchToTensorflow
    output_graph_def.ParseFromString(f.read())
google.protobuf.message.DecodeError: Error parsing message

Could you help me deal with the problem?
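A likely cause, offered as an assumption rather than a confirmed diagnosis: with TensorFlow 2.x, onnx-tf's `export_graph("model.pb")` appears to write a SavedModel *directory* named `model.pb` (the `model.pb/saved_model.pb` path in the code is consistent with that). The `saved_model.pb` file inside it is a serialized `SavedModel` message, which wraps the graph in `MetaGraphDef`s; it is not a bare `GraphDef`, so parsing it with `GraphDef.ParseFromString` is expected to fail. The sketch below uses a hypothetical helper, `detect_export_format` (not part of onnx-tf or TensorFlow), that only inspects the file system to tell a SavedModel directory apart from a single frozen-graph `.pb` file before choosing a loader.

    import os

    def detect_export_format(path):
        """Hypothetical helper: classify a TensorFlow export by its on-disk layout.

        Only the file system is inspected; no protobuf is parsed.
        """
        if os.path.isdir(path):
            # A SavedModel export is a directory containing saved_model.pb
            # (a SavedModel proto, not a bare GraphDef), usually alongside
            # a variables/ subdirectory.
            if os.path.isfile(os.path.join(path, "saved_model.pb")):
                return "saved_model"
            return "unknown_directory"
        if os.path.isfile(path):
            # A single .pb file is typically a serialized (frozen) GraphDef.
            return "graph_def" if path.endswith(".pb") else "unknown_file"
        return "missing"

    print(detect_export_format("model.pb"))

If the helper reports `saved_model`, the usual TensorFlow 2.x route is `tf.saved_model.load("model.pb")` followed by calling one of the loaded object's `signatures`; the `tf.compat.v1.GraphDef` parsing in the code above only applies to a single frozen-graph file.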
