In [33]:
!pip install onnx onnxruntime



In [1]:
import onnx
import onnxruntime as ort
import numpy as np

# Tensor names shared by the model builder and the runner below:
# "X" is the graph input fed to Flatten, "Y" is the graph output it produces.
input1_name, flatten_output_name = "X", "Y"

# Create the ONNX model with Flatten operator
def create_flatten_model(axis, input_rank, output_shape, dtype):

    #Create input tensor
    input1 = onnx.helper.make_tensor_value_info(input1_name, dtype, input_rank)

    # Create output tensor (final result after flatten operation)
    flatten_output = onnx.helper.make_tensor_value_info(flatten_output_name, dtype, output_shape)

    # Define flatten node
    flatten_node = onnx.helper.make_node(
        "Flatten",
        inputs=[input1_name],
        outputs=[flatten_output_name],
        axis=axis
    )

    # Create the ONNX graph
    graph_def = onnx.helper.make_graph(
        [flatten_node],
        "Flatten",
        [input1],
        [flatten_output],
    )

    # Create the ONNX model
    model = onnx.helper.make_model(graph_def, opset_imports=[onnx.helper.make_opsetid("", 22)]) # Explicitly set opset to 22
    model.ir_version = 10 
    onnx.checker.check_model(model)

    # Save the model
    onnx.save(model, "flatten.onnx")

    # Load and run the model with ONNX Runtime
    session = ort.InferenceSession("flatten.onnx")
    return session

def do_flatten(x, session):
    """Feed ``x`` to ``session`` as the graph input and print a before/after summary.

    Prints four lines, in order: input shape, input values, output shape and
    output values. Each array is rendered on a single line with ',' as the
    element separator. Nothing is returned.

    Args:
        x: NumPy array fed to the graph input named by ``input1_name``.
        session: object exposing ``run(output_names, feeds)`` (ONNX Runtime
            style); only the first returned tensor is reported.
    """
    def _one_line(arr):
        # Even with an unbounded line width, array2string breaks between
        # sub-arrays; drop those newlines so each array prints on one line.
        text = np.array2string(arr, separator=',', max_line_width=np.inf)
        return text.replace('\n', '')

    outputs = session.run(None, {input1_name: x})
    flat = outputs[0]

    # Render both arrays before printing anything.
    input_text = _one_line(x)
    result_text = _one_line(flat)

    print("Shape of input tensor:", x.shape)
    print(f"X={input_text}")
    print("Shape of output tensor:", flat.shape)
    print(f"Result = {result_text}")


# Print floats with exactly the current `precision` digits (no trimming of
# trailing zeros); passing precision=None would just leave it unchanged.
np.set_printoptions(floatmode='fixed')

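## Nominal Cases

Each case below flattens the same rank-3 `int32` tensor of shape (2, 3, 4) with a different `axis`. ONNX `Flatten` always returns a 2-D tensor of shape $\left(\prod_{i<\text{axis}} d_i,\ \prod_{i\ge\text{axis}} d_i\right)$, where an empty product is 1 and a negative `axis` is interpreted as `axis + rank`.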

In [2]:
# Case N1: rank-3 tensor (int32), axis=0
onnx_type = onnx.TensorProto.INT32
x = np.arange(24).reshape(2, 3, 4).astype(np.int32)
axis = 0
input_shape = x.shape
# Flatten output is 2-D: (product of dims before axis, product of dims from axis on).
# np.prod of an empty slice returns the float 1.0, hence the int() casts.
dim0 = np.prod(input_shape[:axis])
dim1 = np.prod(input_shape[axis:])
output_shape = [int(dim0), int(dim1)]
# Declare every input dim as dynamic; the rank comes from x itself.
session = create_flatten_model(axis, [None] * x.ndim, output_shape, onnx_type)
do_flatten(x, session)
# Cross-check ONNX Runtime against a plain NumPy reshape of the same data.
np.testing.assert_array_equal(
    session.run(None, {input1_name: x})[0], x.reshape(output_shape)
)

Shape of input tensor: (2, 3, 4)
X=[[[ 0, 1, 2, 3],  [ 4, 5, 6, 7],  [ 8, 9,10,11]], [[12,13,14,15],  [16,17,18,19],  [20,21,22,23]]]
Shape of output tensor: (1, 24)
Result = [[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16,17,18,19,20,21,22,23]]


In [3]:
# Case N2: rank-3 tensor (int32), axis=1
onnx_type = onnx.TensorProto.INT32
x = np.arange(24).reshape(2, 3, 4).astype(np.int32)
axis = 1
input_shape = x.shape
# Flatten output is 2-D: (product of dims before axis, product of dims from axis on).
dim0 = np.prod(input_shape[:axis])
dim1 = np.prod(input_shape[axis:])
output_shape = [int(dim0), int(dim1)]
# Declare every input dim as dynamic; the rank comes from x itself.
session = create_flatten_model(axis, [None] * x.ndim, output_shape, onnx_type)
do_flatten(x, session)
# Cross-check ONNX Runtime against a plain NumPy reshape of the same data.
np.testing.assert_array_equal(
    session.run(None, {input1_name: x})[0], x.reshape(output_shape)
)

Shape of input tensor: (2, 3, 4)
X=[[[ 0, 1, 2, 3],  [ 4, 5, 6, 7],  [ 8, 9,10,11]], [[12,13,14,15],  [16,17,18,19],  [20,21,22,23]]]
Shape of output tensor: (2, 12)
Result = [[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11], [12,13,14,15,16,17,18,19,20,21,22,23]]


In [None]:
# Case N3: rank-3 tensor (int32), axis=2
onnx_type = onnx.TensorProto.INT32
x = np.arange(24).reshape(2, 3, 4).astype(np.int32)
axis = 2
input_shape = x.shape
# Flatten output is 2-D: (product of dims before axis, product of dims from axis on).
dim0 = np.prod(input_shape[:axis])
dim1 = np.prod(input_shape[axis:])
output_shape = [int(dim0), int(dim1)]
# Declare every input dim as dynamic; the rank comes from x itself.
session = create_flatten_model(axis, [None] * x.ndim, output_shape, onnx_type)
do_flatten(x, session)
# Cross-check ONNX Runtime against a plain NumPy reshape of the same data.
np.testing.assert_array_equal(
    session.run(None, {input1_name: x})[0], x.reshape(output_shape)
)

Shape of input tensor: (2, 3, 4)
X=[[[ 0, 1, 2, 3],  [ 4, 5, 6, 7],  [ 8, 9,10,11]], [[12,13,14,15],  [16,17,18,19],  [20,21,22,23]]]
Shape of output tensor: (6, 4)
Result = [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9,10,11], [12,13,14,15], [16,17,18,19], [20,21,22,23]]


In [13]:
# Case N4: rank-3 tensor (int32), axis=3 (axis == rank)
onnx_type = onnx.TensorProto.INT32
x = np.arange(24).reshape(2, 3, 4).astype(np.int32)
axis = 3
input_shape = x.shape
# Flatten output is 2-D: (product of dims before axis, product of dims from axis on).
# Here the trailing slice is empty; np.prod returns the float 1.0, hence int().
dim0 = np.prod(input_shape[:axis])
dim1 = np.prod(input_shape[axis:])
output_shape = [int(dim0), int(dim1)]
# Declare every input dim as dynamic; the rank comes from x itself.
session = create_flatten_model(axis, [None] * x.ndim, output_shape, onnx_type)
do_flatten(x, session)
# Cross-check ONNX Runtime against a plain NumPy reshape of the same data.
np.testing.assert_array_equal(
    session.run(None, {input1_name: x})[0], x.reshape(output_shape)
)

Shape of input tensor: (2, 3, 4)
X=[[[ 0, 1, 2, 3],  [ 4, 5, 6, 7],  [ 8, 9,10,11]], [[12,13,14,15],  [16,17,18,19],  [20,21,22,23]]]
Shape of output tensor: (24, 1)
Result = [[ 0], [ 1], [ 2], [ 3], [ 4], [ 5], [ 6], [ 7], [ 8], [ 9], [10], [11], [12], [13], [14], [15], [16], [17], [18], [19], [20], [21], [22], [23]]


In [14]:
# Case N5: rank-3 tensor (int32), axis=-1
onnx_type = onnx.TensorProto.INT32
x = np.arange(24).reshape(2, 3, 4).astype(np.int32)
axis = -1
input_shape = x.shape
# Python slicing with a negative index counts from the end, matching
# Flatten's axis + rank normalization (here -1 -> 2).
dim0 = np.prod(input_shape[:axis])
dim1 = np.prod(input_shape[axis:])
output_shape = [int(dim0), int(dim1)]
# Declare every input dim as dynamic; the rank comes from x itself.
session = create_flatten_model(axis, [None] * x.ndim, output_shape, onnx_type)
do_flatten(x, session)
# Cross-check ONNX Runtime against a plain NumPy reshape of the same data.
np.testing.assert_array_equal(
    session.run(None, {input1_name: x})[0], x.reshape(output_shape)
)

Shape of input tensor: (2, 3, 4)
X=[[[ 0, 1, 2, 3],  [ 4, 5, 6, 7],  [ 8, 9,10,11]], [[12,13,14,15],  [16,17,18,19],  [20,21,22,23]]]
Shape of output tensor: (6, 4)
Result = [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9,10,11], [12,13,14,15], [16,17,18,19], [20,21,22,23]]


In [16]:
# Case N6: rank-3 tensor (int32), axis=-2
onnx_type = onnx.TensorProto.INT32
x = np.arange(24).reshape(2, 3, 4).astype(np.int32)
axis = -2
input_shape = x.shape
# Python slicing with a negative index counts from the end, matching
# Flatten's axis + rank normalization (here -2 -> 1).
dim0 = np.prod(input_shape[:axis])
dim1 = np.prod(input_shape[axis:])
output_shape = [int(dim0), int(dim1)]
# Declare every input dim as dynamic; the rank comes from x itself.
session = create_flatten_model(axis, [None] * x.ndim, output_shape, onnx_type)
do_flatten(x, session)
# Cross-check ONNX Runtime against a plain NumPy reshape of the same data.
np.testing.assert_array_equal(
    session.run(None, {input1_name: x})[0], x.reshape(output_shape)
)

Shape of input tensor: (2, 3, 4)
X=[[[ 0, 1, 2, 3],  [ 4, 5, 6, 7],  [ 8, 9,10,11]], [[12,13,14,15],  [16,17,18,19],  [20,21,22,23]]]
Shape of output tensor: (2, 12)
Result = [[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11], [12,13,14,15,16,17,18,19,20,21,22,23]]
