In [112]:
import torch
import onnx
from onnx import helper, numpy_helper
import onnxruntime as ort
import numpy as np

### Implementation of the early exit strategy

In [113]:
model = onnx.load("trained_model.onnx") 

In [None]:
print("Entradas:", [input.name for input in model.graph.input])
print("Salidas:", [output.name for output in model.graph.output])

# inspect nodes to identify the early exit
for node in model.graph.node:
    print(node.name, node.op_type, [i for i in node.output])


Entradas: ['input']
Salidas: ['early_exit', 'final_exit', '92', '95']
/conv1/Conv Conv ['/conv1/Conv_output_0']
/Relu Relu ['/Relu_output_0']
/conv2/Conv Conv ['/conv2/Conv_output_0']
/Relu_1 Relu ['/Relu_1_output_0']
/pool1/MaxPool MaxPool ['/pool1/MaxPool_output_0']
/conv3/Conv Conv ['/conv3/Conv_output_0']
/Relu_2 Relu ['/Relu_2_output_0']
/conv4/Conv Conv ['/conv4/Conv_output_0']
/Relu_3 Relu ['/Relu_3_output_0']
/pool2/MaxPool MaxPool ['/pool2/MaxPool_output_0']
/obj_detect_conv/Conv Conv ['/obj_detect_conv/Conv_output_0']
/Relu_4 Relu ['/Relu_4_output_0']
/Flatten Flatten ['/Flatten_output_0']
/obj_detect_fc_ee/Gemm Gemm ['early_exit']
/conv5/Conv Conv ['/conv5/Conv_output_0']
/Relu_5 Relu ['/Relu_5_output_0']
/conv6/Conv Conv ['/conv6/Conv_output_0']
/Relu_6 Relu ['/Relu_6_output_0']
/pool3/MaxPool MaxPool ['/pool3/MaxPool_output_0']
/Flatten_1 Flatten ['/Flatten_1_output_0']
/obj_detect_fc_final/Gemm Gemm ['final_exit']
/binary_fc1/Gemm Gemm ['/binary_fc1/Gemm_output_0']
/Relu_

In [None]:
gemm_weights = np.random.rand(4, 256).astype(np.float32)  # adjust dimensions to  match the original model
gemm_biases = np.random.rand(4).astype(np.float32)

# convert to ONNX initializers
gemm_weights_initializer = numpy_helper.from_array(gemm_weights, name="/obj_detect_fc_final/Gemm_weight")
gemm_biases_initializer = numpy_helper.from_array(gemm_biases, name="/obj_detect_fc_final/Gemm_bias")

# add to the graph
model.graph.initializer.extend([gemm_weights_initializer, gemm_biases_initializer])


In [None]:
weight_tensor = helper.make_tensor(
    name="/obj_detect_fc_ee/Gemm_weight",
    data_type=onnx.TensorProto.FLOAT,
    dims=[256, 4],  # correct dimensions for the Gemm layer
    vals=np.random.rand(256 * 4).astype(np.float32)  
)

bias_tensor = helper.make_tensor(
    name="/obj_detect_fc_ee/Gemm_bias",
    data_type=onnx.TensorProto.FLOAT,
    dims=[4],  # correct dimensions for the bias
    vals=np.random.rand(4).astype(np.float32)  
)

model.graph.initializer.extend([weight_tensor, bias_tensor])


In [None]:
threshold_value = 0.9  # confidence threshold to be adapted
threshold_tensor = numpy_helper.from_array(np.array([threshold_value], dtype=np.float32), name="threshold")
model.graph.initializer.append(threshold_tensor)

reduce_max_node = helper.make_node(
    "ReduceMax",
    inputs=["early_exit"],
    outputs=["max_confidence"],
    name="reduce_max_node",
    keepdims=0
)

greater_node = helper.make_node(
    "Greater",
    inputs=["max_confidence", "threshold"],
    outputs=["early_exit_condition"],
    name="greater_node"
)

model.graph.node.extend([reduce_max_node, greater_node])


In [118]:
then_branch = helper.make_graph(
    nodes=[
        helper.make_node(
            "Gemm",
            inputs=["/Flatten_output_0", "/obj_detect_fc_ee/Gemm_weight", "/obj_detect_fc_ee/Gemm_bias"],
            outputs=["output_early_exit"],
            name="/obj_detect_fc_ee/Gemm"
        )
    ],
    name="then_branch",
    inputs=[
        helper.make_tensor_value_info("/Flatten_output_0", onnx.TensorProto.FLOAT, [1, 256])
    ],
    outputs=[
        helper.make_tensor_value_info("output_early_exit", onnx.TensorProto.FLOAT, [1, 4])
    ]
)


In [None]:
else_branch = helper.make_graph(
    nodes=[
        helper.make_node(
            "Gemm",
            inputs=["/Flatten_1_output_0", "/obj_detect_fc_final/Gemm_weight", "/obj_detect_fc_final/Gemm_bias"],
            outputs=["output_final_exit"],
            name="/obj_detect_fc_final/Gemm"
        )
    ],
    name="else_branch",
    inputs=[
        helper.make_tensor_value_info("/Flatten_1_output_0", onnx.TensorProto.FLOAT, [1, 256])  
    ],
    outputs=[
        helper.make_tensor_value_info("output_final_exit", onnx.TensorProto.FLOAT, [1, 4])  
    ]
)


In [120]:
if_node = helper.make_node(
    "If",
    inputs=["early_exit_condition"],
    outputs=["final_exit_if"],
    then_branch=then_branch,
    else_branch=else_branch,
    name="if_node"
)

model.graph.node.append(if_node)


input: "early_exit_condition"
output: "final_exit_if"
name: "if_node"
op_type: "If"
attribute {
  name: "else_branch"
  g {
    node {
      input: "/Flatten_1_output_0"
      input: "/obj_detect_fc_final/Gemm_weight"
      input: "/obj_detect_fc_final/Gemm_bias"
      output: "output_final_exit"
      name: "/obj_detect_fc_final/Gemm"
      op_type: "Gemm"
    }
    name: "else_branch"
    input {
      name: "/Flatten_1_output_0"
      type {
        tensor_type {
          elem_type: 1
          shape {
            dim {
              dim_value: 1
            }
            dim {
              dim_value: 256
            }
          }
        }
      }
    }
    output {
      name: "output_final_exit"
      type {
        tensor_type {
          elem_type: 1
          shape {
            dim {
              dim_value: 1
            }
            dim {
              dim_value: 4
            }
          }
        }
      }
    }
  }
  type: GRAPH
}
attribute {
  name: "then_branch"
  g

In [None]:
# check graph execution is valid
onnx.checker.check_model(model)