In [None]:
import onnx_graphsurgeon  as gs
import onnx

In [None]:
def make_input_multidimensional(graph, input_name:str, shape=(1, 1, 1, 1)):
    """Overwrite the declared shape of a single graph input.

    The notebook uses this to give the decoder's ``has_mask_input`` a 4-D
    ``[1, 1, 1, 1]`` shape before re-exporting the model. Only the input
    tensor's declared shape is changed; the nodes that consume it are not
    touched.

    Args:
        graph: onnx_graphsurgeon graph whose ``inputs`` are searched; the
            matching input tensor is modified in place.
        input_name: name of the graph input to reshape.
        shape: new shape to assign. Defaults to ``(1, 1, 1, 1)``, the value
            this helper always used. It is stored on the tensor as a new list.

    Returns:
        The modified input tensor.

    Raises:
        ValueError: if ``graph`` has no input called ``input_name``.
    """
    matches = [tensor for tensor in graph.inputs if tensor.name == input_name]
    if not matches:
        # Report the valid names instead of failing with an opaque IndexError
        available = [tensor.name for tensor in graph.inputs]
        raise ValueError(
            f"graph has no input named {input_name!r}; available inputs: {available}"
        )
    target_input = matches[0]
    # Fresh list each call so the default value is never aliased by a tensor
    target_input.shape = list(shape)
    return target_input
    
def add_cast_before_outputs(graph, output_names:list, dtype=onnx.TensorProto.FLOAT):
    """Insert a ``Cast`` node in front of each named graph output.

    For every requested name, the existing output tensor is renamed to
    ``"<name>_outdated"`` and fed into a new ``Cast`` node. The Cast's result
    tensor takes over both the original name and the original position in
    ``graph.outputs``. Consumers of the model therefore see the same output
    names, in the same order, with element type ``dtype``.

    Args:
        graph: onnx_graphsurgeon graph to modify in place.
        output_names: names of the graph outputs to cast. Any iterable of
            strings is accepted; a single string is treated as one name, and
            repeated names are cast only once.
        dtype: target element type. It is used both as the Cast ``to``
            attribute and as the dtype of the new output tensor. Defaults to
            ``onnx.TensorProto.FLOAT`` (fp32).

    Returns:
        The result of ``graph.cleanup().toposort()``. GraphSurgeon documents
        both calls as returning the graph itself.

    Raises:
        ValueError: if a requested name is not a current graph output.
    """
    if isinstance(output_names, str):
        # A bare string would otherwise be iterated character by character
        output_names = [output_names]

    # dict.fromkeys de-duplicates while preserving order. A repeated name would
    # otherwise find the freshly created output and cast it a second time.
    for output_name in dict.fromkeys(output_names):
        output_index = next(
            (i for i, tensor in enumerate(graph.outputs) if tensor.name == output_name),
            None,
        )
        if output_index is None:
            available = [tensor.name for tensor in graph.outputs]
            raise ValueError(
                f"graph has no output named {output_name!r}; available outputs: {available}"
            )
        target_output = graph.outputs[output_index]

        # Move the old tensor out of the way so the new output can reuse its name
        target_output.name = f"{output_name}_outdated"

        # New graph output carrying the casted values under the original name
        new_output = gs.Variable(name=output_name, shape=target_output.shape, dtype=dtype)
        new_casting_node = gs.Node(
            "Cast",
            f"cast_{output_name}",
            attrs={"to": dtype},
            inputs=[target_output],
            outputs=[new_output],
        )
        graph.nodes.append(new_casting_node)

        # Replace at the same index. Appending and then removing would move
        # this output to the end and change the model's output order.
        graph.outputs[output_index] = new_output

    # Drop dangling nodes/tensors and re-sort so the appended Cast nodes are
    # placed after their producers.
    return graph.cleanup().toposort()

In [None]:
# Patch the exported decoder: give `has_mask_input` a 4-D shape and cast both
# outputs (default dtype: fp32), then write the result into the same folder.
decoder_path = "onnx/decoder.onnx"
output_path = "onnx/casting_decoder.onnx"

graph = gs.import_onnx(onnx.load(decoder_path))
make_input_multidimensional(graph, "has_mask_input")
clean_graph = add_cast_before_outputs(
    graph=graph,
    output_names=["iou_predictions", "low_res_masks"],
)

# Save the modified ONNX model
casting_model = gs.export_onnx(clean_graph)
onnx.save(casting_model, output_path)

In [None]:
# Re-read the unmodified decoder from disk. The patching cell changed `graph`
# in place, so this restores a view of the original model for inspection.
original_decoder = onnx.load("onnx/decoder.onnx")
graph = gs.import_onnx(original_decoder)

In [None]:
# Show the inputs of the graph most recently loaded into `graph`. After the
# reload cell this is the unpatched decoder, which lets you compare against
# the shape that make_input_multidimensional assigns to `has_mask_input`.
graph.inputs