In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import onnx
import os
from onnx import optimizer

# Preprocessing: load the example model, whose graph contains two
# consecutive Transpose nodes.
model_path = os.path.join('resources', 'two_transposes.onnx')
original_model = onnx.load(model_path)

# Dump the graph as loaded so it can be compared with the optimized version.
print(
    'The model before optimization:\n{}'.format(original_model)
)

The model before optimization:
ir_version: 3
producer_name: "onnx-examples"
graph {
  node {
    input: "X"
    output: "Y"
    op_type: "Transpose"
    attribute {
      name: "perm"
      ints: 1
      ints: 0
      ints: 2
      type: INTS
    }
  }
  node {
    input: "Y"
    output: "Z"
    op_type: "Transpose"
    attribute {
      name: "perm"
      ints: 1
      ints: 0
      ints: 2
      type: INTS
    }
  }
  name: "two-transposes"
  input {
    name: "X"
    type {
      tensor_type {
        elem_type: FLOAT
        shape {
          dim {
            dim_value: 2
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 4
          }
        }
      }
    }
  }
  output {
    name: "Z"
    type {
      tensor_type {
        elem_type: FLOAT
        shape {
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 2
          }
          dim {
            dim_value: 4
          }
        }
    

In [2]:
# Choose which optimization passes to run. A full list of supported passes
# can be found here:
# https://github.com/onnx/onnx/blob/master/onnx/optimizer.py#L21
passes = ['fuse_consecutive_transposes']

# Run the selected passes on the model loaded in the previous cell.
optimized_model = optimizer.optimize(original_model, passes)

# Dump the resulting graph for comparison with the pre-optimization output.
print(
    'The model after optimization:\n{}'.format(optimized_model)
)

The model after optimization:
ir_version: 3
producer_name: "onnx-examples"
graph {
  node {
    input: "X"
    output: "Z"
    op_type: "Transpose"
    attribute {
      name: "perm"
      ints: 0
      ints: 1
      ints: 2
      type: INTS
    }
  }
  name: "two-transposes"
  input {
    name: "X"
    type {
      tensor_type {
        elem_type: FLOAT
        shape {
          dim {
            dim_value: 2
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 4
          }
        }
      }
    }
  }
  output {
    name: "Z"
    type {
      tensor_type {
        elem_type: FLOAT
        shape {
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 2
          }
          dim {
            dim_value: 4
          }
        }
      }
    }
  }
}
opset_import {
  version: 6
}

