In [6]:
import onnx
import numpy as np
from onnx import helper, numpy_helper, TensorProto, shape_inference

# Paths
input_model = "/home/adelb/Downloads/best-6c-map92.onnx"
output_model = "yolov8_seg_filtered_1class.onnx"

# Assumed layout of this export (standard Ultralytics YOLOv8-seg; TODO confirm in Netron):
#   output0: [1, NUM_BOX + NUM_CLASSES + NUM_MASK_COEFFS, num_anchors]
#            channels = [box (4) | per-class scores (NUM_CLASSES) | mask coefficients (32)]
#   output1: [1, NUM_MASK_COEFFS, H, W] prototype masks. These are indexed by mask channel,
#            not by anchor, so they are left untouched; each kept anchor carries its own
#            mask coefficients inside output0.
NUM_BOX = 4
NUM_CLASSES = 6
NUM_MASK_COEFFS = 32
CLASS_ID = 1  # 0-based index of the single class to keep
if not 0 <= CLASS_ID < NUM_CLASSES:
    raise ValueError(f"CLASS_ID must be in [0, {NUM_CLASSES}), got {CLASS_ID}")

ORIGINAL_OUTPUT = "output0"
RENAMED_OUTPUT = "output0_original"
PREFIX = "cls_filter_"  # prefix for every new tensor/initializer/node name, avoids clashes with the export

# Load model
model = onnx.load(input_model)
graph = model.graph


def add_int64_const(graph, name, value):
    """Append an int64 initializer named PREFIX + name to `graph` and return that name.

    `value` may be a Python int (0-D tensor) or a list of ints (1-D tensor).
    """
    full_name = PREFIX + name
    graph.initializer.append(
        numpy_helper.from_array(np.asarray(value, dtype=np.int64), name=full_name)
    )
    return full_name


def slice_channels(graph, source, start, end, label):
    """Return a Slice node keeping channels [start, end) of tensor `source` along axis 1.

    The node's single output tensor is named PREFIX + label.
    """
    return helper.make_node(
        "Slice",
        [
            source,
            add_int64_const(graph, label + "_starts", [start]),
            add_int64_const(graph, label + "_ends", [end]),
            add_int64_const(graph, label + "_axes", [1]),
        ],
        [PREFIX + label],
        name=PREFIX + "Slice_" + label,
    )


# 0. Sanity checks before touching the graph.
output0_info = next((o for o in graph.output if o.name == ORIGINAL_OUTPUT), None)
if output0_info is None:
    raise ValueError(f"graph has no output named {ORIGINAL_OUTPUT!r}")
producer_count = sum(ORIGINAL_OUTPUT in node.output for node in graph.node)
if producer_count != 1:
    raise ValueError(f"expected exactly one node producing {ORIGINAL_OUTPUT!r}, found {producer_count}")

# 1. Rename the *tensor* output0 -> output0_original at its producer (and any consumers).
#    Renaming only the graph.output entry would leave the producing node still writing
#    "output0". The new Concat below would then be a second producer of the same name,
#    which violates ONNX's single-static-assignment rule.
for node in graph.node:
    for i in range(len(node.output)):
        if node.output[i] == ORIGINAL_OUTPUT:
            node.output[i] = RENAMED_OUTPUT
    for i in range(len(node.input)):
        if node.input[i] == ORIGINAL_OUTPUT:
            node.input[i] = RENAMED_OUTPUT
for value_info in graph.value_info:
    if value_info.name == ORIGINAL_OUTPUT:
        value_info.name = RENAMED_OUTPUT

# 2. Find the anchors whose best-scoring class is CLASS_ID.
class_scores_node = slice_channels(
    graph, RENAMED_OUTPUT, NUM_BOX, NUM_BOX + NUM_CLASSES, "class_scores"
)  # [1, NUM_CLASSES, N]
argmax_node = helper.make_node(
    "ArgMax", [PREFIX + "class_scores"], [PREFIX + "best_class"],
    axis=1, keepdims=0, name=PREFIX + "ArgMaxClasses",
)  # [1, N] int64
equal_node = helper.make_node(
    "Equal", [PREFIX + "best_class", add_int64_const(graph, "class_id", CLASS_ID)],
    [PREFIX + "is_target"], name=PREFIX + "EqualClass",
)  # [1, N] bool
nonzero_node = helper.make_node(
    "NonZero", [PREFIX + "is_target"], [PREFIX + "target_coords"],
    name=PREFIX + "NonZeroMatch",
)  # [2, K]: row 0 = index along axis 0 (batch), row 1 = anchor index
# A Gather with a scalar index drops axis 0 and yields [K] anchor indices. This replaces
# Squeeze(axes=[0]), which cannot squeeze a size-2 dimension; Squeeze's `axes` attribute
# also became an input in opset 13. Assumes batch size 1.
anchor_idx_node = helper.make_node(
    "Gather", [PREFIX + "target_coords", add_int64_const(graph, "anchor_row", 1)],
    [PREFIX + "anchor_idx"], axis=0, name=PREFIX + "GatherAnchorIdx",
)

# 3. Keep those anchors (all channels), then rebuild a single-class output0.
gather_node = helper.make_node(
    "Gather", [RENAMED_OUTPUT, PREFIX + "anchor_idx"], [PREFIX + "kept"],
    axis=2, name=PREFIX + "GatherKeptAnchors",
)  # [1, C, K]
box_node = slice_channels(graph, PREFIX + "kept", 0, NUM_BOX, "kept_box")
score_node = slice_channels(
    graph, PREFIX + "kept", NUM_BOX + CLASS_ID, NUM_BOX + CLASS_ID + 1, "kept_score"
)
mask_node = slice_channels(
    graph, PREFIX + "kept", NUM_BOX + NUM_CLASSES,
    NUM_BOX + NUM_CLASSES + NUM_MASK_COEFFS, "kept_mask_coeffs",
)
# Channel order [box | score | mask coeffs] is the layout a single-class YOLOv8-seg
# output0 is assumed to have. All three inputs share output0's dtype, so Concat type-checks.
concat_node = helper.make_node(
    "Concat",
    [box_node.output[0], score_node.output[0], mask_node.output[0]],
    [ORIGINAL_OUTPUT],
    axis=1,
    name=PREFIX + "ConcatFinalOutput",
)

# Appended after all existing nodes, so the producer of output0_original precedes them.
graph.node.extend([
    class_scores_node, argmax_node, equal_node, nonzero_node, anchor_idx_node,
    gather_node, box_node, score_node, mask_node, concat_node,
])

# 4. Re-declare graph output "output0" with its new shape (same dtype); output1 is unchanged.
out_channels = NUM_BOX + 1 + NUM_MASK_COEFFS
elem_type = output0_info.type.tensor_type.elem_type
output0_info.CopyFrom(
    helper.make_tensor_value_info(ORIGINAL_OUTPUT, elem_type, [1, out_channels, None])
)

# Validate, then save with inferred intermediate shapes
onnx.checker.check_model(model)
inferred = shape_inference.infer_shapes(model)
onnx.save(inferred, output_model)
print(f"Modified model saved to {output_model}")


Modified model saved to yolov8_seg_filtered_1class.onnx


In [8]:
import onnx
from onnx import helper, TensorProto

# Load the model written by the filtering cell
model = onnx.load("yolov8_seg_filtered_1class.onnx")
graph = model.graph

# Outputs to expose: the filtered detections and the original prototype-mask output.
new_outputs = [
    helper.make_tensor_value_info("output0", TensorProto.FLOAT, [1, 37, None]),       # filtered detections
    helper.make_tensor_value_info("output1", TensorProto.FLOAT, [1, 32, 160, 160])    # Original mask output
]

# A graph output must name an existing tensor (a node output, graph input or initializer).
# Otherwise the saved model is invalid, so fail here instead of writing it.
available = {name for node in graph.node for name in node.output}
available |= {init.name for init in graph.initializer}
available |= {inp.name for inp in graph.input}
missing = [vi.name for vi in new_outputs if vi.name not in available]
if missing:
    raise ValueError(f"no tensor(s) named {missing} in the graph; cannot expose them as outputs")

# Replace all existing outputs
del graph.output[:]
graph.output.extend(new_outputs)

# The checker also rejects graphs where one tensor name has several producers
onnx.checker.check_model(model)

# Save updated model
onnx.save(model, "your_model_cleaned.onnx")
print("Outputs updated: now only 'output0' and 'output1'.")



Outputs updated: now only 'output0' and 'output1'.


In [9]:
import onnx
from onnx import helper, TensorProto

# Reload the original export so this cell does not depend on earlier cells
model_path = "/home/adelb/Downloads/best-6c-map92.onnx"
model = onnx.load(model_path)
graph = model.graph

# Assumed output0 layout (standard Ultralytics YOLOv8-seg; TODO confirm in Netron):
#   [1, NUM_BOX + NUM_CLASSES + NUM_MASK_COEFFS, num_anchors]
#   channels = [box (4) | class scores (6) | mask coefficients (32)]
NUM_BOX = 4
NUM_CLASSES = 6
NUM_MASK_COEFFS = 32
ORIGINAL_OUTPUT = "output0"
RENAMED_OUTPUT = "output0_original"

# STEP 1: Rename the *tensor* output0 -> output0_original where it is produced/consumed.
# Renaming only the graph.output entry would leave the producer still writing "output0".
# The Slice input "output0_original" would then have no producer, and the Slice result
# would never be exposed, so the saved model would still return the unsliced tensor.
producer_count = sum(ORIGINAL_OUTPUT in node.output for node in graph.node)
if producer_count != 1:
    raise ValueError(f"expected exactly one node producing {ORIGINAL_OUTPUT!r}, found {producer_count}")
for node in graph.node:
    for i in range(len(node.output)):
        if node.output[i] == ORIGINAL_OUTPUT:
            node.output[i] = RENAMED_OUTPUT
    for i in range(len(node.input)):
        if node.input[i] == ORIGINAL_OUTPUT:
            node.input[i] = RENAMED_OUTPUT
for value_info in graph.value_info:
    if value_info.name == ORIGINAL_OUTPUT:
        value_info.name = RENAMED_OUTPUT

# STEP 2: Drop the box channels along axis 1, keeping class scores + mask coefficients
slice_begin = NUM_BOX
slice_end = NUM_BOX + NUM_CLASSES + NUM_MASK_COEFFS  # 42
kept_channels = slice_end - slice_begin              # 38
start_tensor = helper.make_tensor("slice_start", TensorProto.INT64, [1], [slice_begin])
end_tensor = helper.make_tensor("slice_end", TensorProto.INT64, [1], [slice_end])
axes_tensor = helper.make_tensor("slice_axes", TensorProto.INT64, [1], [1])
steps_tensor = helper.make_tensor("slice_steps", TensorProto.INT64, [1], [1])
graph.initializer.extend([start_tensor, end_tensor, axes_tensor, steps_tensor])

# The Slice now *produces* "output0", so the graph output below refers to the sliced tensor
slice_node = helper.make_node(
    "Slice",
    inputs=[RENAMED_OUTPUT, "slice_start", "slice_end", "slice_axes", "slice_steps"],
    outputs=[ORIGINAL_OUTPUT],
    name="SliceFilteredOutput"
)

# STEP 3: Append the slicing node (after its producer, so topological order holds)
graph.node.append(slice_node)

# STEP 4: Replace graph outputs
del graph.output[:]
graph.output.extend([
    helper.make_tensor_value_info(ORIGINAL_OUTPUT, TensorProto.FLOAT, [1, kept_channels, None]),  # filtered output
    helper.make_tensor_value_info("output1", TensorProto.FLOAT, [1, 32, 160, 160])  # unchanged
])

# Fail fast on dangling inputs or duplicate producers instead of saving a broken model
onnx.checker.check_model(model)

# Save modified model
filtered_model_path = "filtered.onnx"
onnx.save(model, filtered_model_path)

filtered_model_path

'filtered.onnx'