# Parameters

In [None]:
MODEL_BASE_DIR = "/home/mateo/cancer-ai/manager/models"

# Utils

In [None]:
def get_model_list(folder_path):
    import os

    # List to store ONNX model paths
    onnx_models = []

    # Loop through the folder
    for file_name in os.listdir(folder_path):
        if file_name.endswith(".onnx"):
            onnx_models.append(os.path.join(folder_path, file_name))

    return onnx_models

In [None]:
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b7, EfficientNet_B7_Weights
import onnxruntime as ort
import torch.nn.functional as F
import numpy as np

whitelist1 = [0, 3, 10]  # Example: classes for model1
whitelist2 = [1, 2, 4, 5, 6, 7, 8, 9]


class OnnxCombineModel(nn.Module):
    def __init__(self, model_path1, model_path2):
        super().__init__()

        self.session1 = ort.InferenceSession(model_path1)
        self.session2 = ort.InferenceSession(model_path2)
        self.input_names1 = [inp.name for inp in self.session1.get_inputs()]
        self.input_names2 = [inp.name for inp in self.session2.get_inputs()]

    def forward(self, image, demographics):
        print("|" * 60)
        print(self.input_names1)
        print(self.input_names2)
        inputs1 = {self.input_names1[0]: image, self.input_names1[1]: demographics}
        # inputs2 = {self.input_names2[0]: image} ### 43  84 model ###
        inputs2 = {
            self.input_names2[0]: image,
            self.input_names2[1]: demographics,
        }  ### 108 grose model ###

        outputs1 = self.session1.run(None, inputs1)
        outputs2 = self.session2.run(None, inputs2)

        probs1 = outputs1[0].flatten()
        idx = np.argmax(probs1)

        if idx in whitelist1:
            print('ðŸ”´','model1')
            return F.softmax(torch.tensor(probs1 * 3), dim=0)

        probs2 = outputs2[0].flatten()  ### 43 84 model ###
        idx = np.argmax(probs2)
        if idx in whitelist2:
            print('ðŸ”´','model2')
            return F.softmax(torch.tensor(probs2 * 3), dim=0)

        probs = probs1 * 0.5 + 0.5 * probs2
        probs = probs * 3
        return F.softmax(torch.tensor(probs), dim=0)

Test Pytorch

In [None]:
import numpy as np
from PIL import Image

model = OnnxCombineModel(
    "../../models/2025-11-27/speechmaster/18_model118.onnx",
    "../../models/2025-11-27/speechmaster/62_model94.onnx",
)
device = "cpu"
image = Image.open(
    f"../../dataset/dataset00016/0a605167-4e6e-4104-bc06-1aee2e71b33b.jpg"
).convert("RGB")
image = image.resize((512, 512))
image = np.array(image, dtype=np.float32)
image = image * (1.0 / 255.0)

image = np.transpose(image, (2, 0, 1))
image = torch.from_numpy(image).to(device)
image = image.unsqueeze(0)
# print(image.shape)
data = torch.tensor([30, 0, 6], dtype=torch.float32).unsqueeze(0).to(device)
image = image.numpy()
data = data.numpy()
result = model(image, data)
print(result)

Test Onnx Version

In [None]:
import torchvision.transforms as transforms
import onnxruntime as ort


CLASS_NAMES = [
    "Actinic keratosis (AK)",
    "Basal cell carcinoma (BCC)",
    "Seborrheic keratosis (SK)",
    "Squamous cell carcinoma (SCC)",
    "Vascular lesion (VASC)",
    "Dermatofibroma (DF)",
    "Benign nevus (NV)",
    "Other non-neoplastic (NON)",
    "Melanoma (MEL)",
    "Other neoplastic (ON)",
]


class ONNXInference:
    def __init__(self, model_path):
        """Initialize ONNX model session."""

        self.session = ort.InferenceSession(model_path)
        self.input_names = [inp.name for inp in self.session.get_inputs()]

        # Image preprocessing
        self.transform = transforms.Compose(
            [
                transforms.Resize((512, 512)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def preprocess_image(self, image_path):
        """Load and preprocess image to [0,512] range as specified."""
        img = Image.open(image_path).convert("RGB")
        # Resize to 512x512
        img = img.resize((512, 512))
        # Convert to numpy array with [0,512] range
        img_array = np.array(img, dtype=np.float32)
        # Scale from [0,255] to [0,512]
        img_array = img_array * (1 / 255.0)
        # Convert to BCHW format
        img_array = np.transpose(img_array, (2, 0, 1))
        img_array = np.expand_dims(img_array, axis=0)
        return img_array

    def predict(self, image_path, age, gender, location):
        """Run inference on a single image with demographic data."""
        # Preprocess image
        image_tensor = self.preprocess_image(image_path)

        # Convert demographics to proper format
        # Gender: 'm' -> 1.0, 'f' -> 0.0
        gender_encoded = 1.0 if gender.lower() == "m" else 0.0

        # Prepare demographic data as [age, gender_encoded, location]
        demo_tensor = np.array(
            [[float(age), gender_encoded, float(location)]], dtype=np.float32
        )

        # Run inference
        inputs = {self.input_names[0]: image_tensor, self.input_names[1]: demo_tensor}
        # inputs = {self.input_names[0]: image_tensor}

        # inputs = {self.input_names[0]: image_tensor}
        outputs = self.session.run(None, inputs)
        print(outputs)
        # Model already outputs probabilities (softmax applied in forward pass)
        probs = outputs[0].flatten()

        # Get top 3 predictions
        top3_idx = np.argsort(probs)[-3:][::-1]
        top3 = [(CLASS_NAMES[i], float(probs[i])) for i in top3_idx]

        return top3


print("----------------")
# ort.InferenceSession("onnx/combined-2.onnx")
onnx_model = ONNXInference("../../models/combine/2025-11-27/18vs62_1_down.onnx")
# onnx_model = ONNXInference("model/84.onnx")
predictions = onnx_model.predict(
    f"../../dataset/dataset00016/0a605167-4e6e-4104-bc06-1aee2e71b33b.jpg", 30, "f", 6
)
print(predictions)

## Softmax first model

### Combining 36 vs 43

In [None]:
import onnx
import onnx_graphsurgeon as gs
import numpy as np
from onnx import shape_inference
from typing import List

def fix_reduction_nodes(graph: gs.Graph, graph_name: str = "unknown"):
    """
    Fixes ReduceL2 and ReduceMean nodes that incorrectly have axes as input (2 inputs) by moving axes to attribute.
    Searches for the Constant node producing the axes Variable and extracts its value.
    Removes the axes input and the unused Constant node after fix.
    Adds debug prints for all ReduceL2 and ReduceMean nodes.
    """
    fixed_count = 0
    removed_constants = 0
    debug_nodes = []
    for node in graph.nodes:
        if node.op in ['ReduceL2', 'ReduceMean']:
            debug_nodes.append({
                'name': node.name,
                'op': node.op,
                'inputs_count': len(node.inputs),
                'inputs_types': [type(inp).__name__ for inp in node.inputs],
                'second_input_name': node.inputs[1].name if len(node.inputs) > 1 else None
            })
            if len(node.inputs) == 2:
                data_input = node.inputs[0]
                axes_var = node.inputs[1]
                # Search for Constant node producing axes_var
                constant_node = None
                axes_values = None
                for c_node in graph.nodes:
                    if (c_node.op == 'Constant' and 
                        c_node.outputs and len(c_node.outputs) == 1 and 
                        c_node.outputs[0].name == axes_var.name):
                        constant_node = c_node
                        if 'value' in c_node.attrs:
                            axes_values = c_node.attrs['value'].values
                            if isinstance(axes_values, np.ndarray):
                                axes_values = axes_values.tolist()
                        break
                if constant_node and axes_values is not None:
                    # Update node: remove second input, add axes attr
                    node.inputs = [data_input]
                    node.attrs['axes'] = axes_values
                    # Ensure keepdims is set (default 1 for most reductions)
                    if 'keepdims' not in node.attrs:
                        node.attrs['keepdims'] = 1
                    fixed_count += 1
                    print(f"[{graph_name}] Fixed {node.op} node '{node.name}': axes {axes_values} extracted from Constant '{constant_node.name}'")
                    # Mark for removal; cleanup will handle unused nodes
                    removed_constants += 1
                else:
                    print(f"[{graph_name}] Warning: Could not find/extract axes for {node.op} '{node.name}'; second input '{axes_var.name}', Constant found: {constant_node is not None}")
    if debug_nodes:
        print(f"[{graph_name}] Total {', '.join(set(dn['op'] for dn in debug_nodes))} nodes: {len(debug_nodes)}, Fixed: {fixed_count}")
        for dn in debug_nodes[:3]:  # Print first 3 for brevity
            print(f"  - {dn['name']}: {dn['op']}, {dn['inputs_count']} inputs, types: {dn['inputs_types']}, second_name: {dn['second_input_name']}")
        if len(debug_nodes) > 3:
            print(f"  ... and {len(debug_nodes)-3} more")
    return fixed_count

def create_combined_onnx(model_path1, model_path2, output_path='combined.onnx'):
    """
    Combines two ONNX models into one:
    - Model1: takes 'image' and 'demographics' -> logits1
    - Model2: takes 'image' -> logits2
    - Combined: takes 'image' and 'demographics' -> (softmax(logits1) + logits2) / 2
    
    Note: This averages probabilities from Model1 with raw logits from Model2, which may not be semantically ideal
    due to scale differences (probabilities in [0,1], logits unbounded). Consider if softmax should also be applied
    to Model2 or if the final average should be softmaxed.
    
    Assumes:
    - Both models output a single tensor of shape [batch_size, num_classes] (logits).
    - Input names: Model1 has two inputs (first: image, second: demographics); Model2 has one (image).
    - You need to pip install onnx onnx-graphsurgeon if not already installed.
    """
    # Load the models
    onnx_model1 = onnx.load(model_path1)
    onnx_model2 = onnx.load(model_path2)
    
    # Import into graph surgeon
    graph1 = gs.import_onnx(onnx_model1)
    graph2 = gs.import_onnx(onnx_model2)
    
    # Fix reduction nodes in BOTH graphs for thoroughness
    total_fixed = 0
    total_fixed += fix_reduction_nodes(graph1, "Model1")
    total_fixed += fix_reduction_nodes(graph2, "Model2")
    if total_fixed == 0:
        print("No reduction fixes applied - check debug output above")
    
    # Assume input names and order based on your code
    # Rename for clarity and sharing
    image_input = graph1.inputs[0]
    image_input.name = 'image'
    
    demographics_input = graph1.inputs[1]
    demographics_input.name = 'demographics'
    
    # Share the image input with model2
    old_image_input = graph2.inputs[0]
    shared_image_input = image_input  # Shared reference, named 'image'
    
    # Replace all references in graph2 nodes from old_image_input to shared_image_input
    for node in graph2.nodes:
        for i in range(len(node.inputs)):
            if node.inputs[i] is old_image_input:
                node.inputs[i] = shared_image_input
    
    # Update graph2's inputs list to use the shared input
    graph2.inputs[0] = shared_image_input
    
    # Get outputs (assume single output each)
    logits1 = graph1.outputs[0]
    logits1.name = 'logits1'
    
    logits2 = graph2.outputs[0]
    logits2.name = 'logits2'
    
    # Extract num_classes from logits1 shape (assume [batch, num_classes]; batch dynamic)
    orig_shape = logits1.shape
    if orig_shape and len(orig_shape) >= 2:
        num_classes = orig_shape[-1]
        if num_classes == 0 or num_classes is None:
            num_classes = 10  # Fallback assumption based on reported output size
        output_shape = [None, num_classes]  # Dynamic batch
    else:
        output_shape = [None, 10]  # Fallback
        num_classes = 10
        print(f"Warning: Could not infer num_classes from shape {orig_shape}; using fallback [None, 10]")
    
    print(f"Inferred output shape: {output_shape}")
    
    # Define output variables WITH dtype and shape (no flattening)
    probs1 = gs.Variable('probs1', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    sum_avg = gs.Variable('sum_avg', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    avg_output = gs.Variable('avg_output', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    avg_output1 = gs.Variable('avg_output1', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    avg_output2 = gs.Variable('avg_output2', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    
    # Softmax on first model (axis=1 for [batch, classes])
    softmax1 = gs.Node(
        op='Softmax',
        inputs=[logits1],
        outputs=[probs1],
        attrs={'axis': 1}
    )
    
    # Average: (probs1 + logits2) / 2

    
    constant_07 = gs.Constant(name='constant_07', values=np.array(0.4, dtype=np.float32))  # Scalar for broadcast
    constant_03 = gs.Constant(name='constant_03', values=np.array(0.6, dtype=np.float32))  # Scalar for broadcast
    mul1 = gs.Node(
        op='Mul',  # Equivalent to / 2
        inputs=[probs1, constant_07],
        outputs=[avg_output1]
    )
    mul2 = gs.Node(
        op='Mul',  # Equivalent to / 2
        inputs=[logits2, constant_03],
        outputs=[avg_output2]
    )
    
    add = gs.Node(
        op='Add',
        inputs=[avg_output1, avg_output2],
        outputs=[sum_avg]
    )
    # Combined graph: nodes from both + new nodes; inputs: image + demographics; output: avg_output
    combined_graph = gs.Graph(
        nodes=graph1.nodes + graph2.nodes + [softmax1, mul1, mul2, add],
        inputs=[shared_image_input, demographics_input],
        outputs=[sum_avg]
    )
    
    # Set opset on the graph for LayerNormalization support (opset 17+)
    combined_graph.opset = 17
    
    # Cleanup and export
    combined_model = gs.export_onnx(combined_graph.cleanup())
    
    # Infer shapes to fill in any missing (helps checker)
    combined_model = shape_inference.infer_shapes(combined_model)
    
    # Optional: Check model
    onnx.checker.check_model(combined_model)
    
    # Save
    onnx.save(combined_model, output_path)
    print(f"Combined ONNX model saved to {output_path}")
    print(f"Output shape: {output_shape}")
    
    return combined_model

# Usage
combined = create_combined_onnx('model/36.onnx', 'model/43_modelvip.onnx', "model/test(7.3).onnx")

### Combining 36 vs 84

In [None]:
import onnx
import onnx_graphsurgeon as gs
import numpy as np
from onnx import shape_inference
from typing import List


def fix_reduction_nodes(graph: gs.Graph, graph_name: str = "unknown"):
    """
    Fixes ReduceL2 and ReduceMean nodes that incorrectly have axes as input (2 inputs) by moving axes to attribute.
    Searches for the Constant node producing the axes Variable and extracts its value.
    Removes the axes input and the unused Constant node after fix.
    Adds debug prints for all ReduceL2 and ReduceMean nodes.
    """
    fixed_count = 0
    removed_constants = 0
    debug_nodes = []
    for node in graph.nodes:
        if node.op in ['ReduceL2', 'ReduceMean']:
            debug_nodes.append({
                'name': node.name,
                'op': node.op,
                'inputs_count': len(node.inputs),
                'inputs_types': [type(inp).__name__ for inp in node.inputs],
                'second_input_name': node.inputs[1].name if len(node.inputs) > 1 else None
            })
            if len(node.inputs) == 2:
                data_input = node.inputs[0]
                axes_var = node.inputs[1]
                # Search for Constant node producing axes_var
                constant_node = None
                axes_values = None
                for c_node in graph.nodes:
                    if (c_node.op == 'Constant' and 
                        c_node.outputs and len(c_node.outputs) == 1 and 
                        c_node.outputs[0].name == axes_var.name):
                        constant_node = c_node
                        if 'value' in c_node.attrs:
                            axes_values = c_node.attrs['value'].values
                            if isinstance(axes_values, np.ndarray):
                                axes_values = axes_values.tolist()
                        break
                if constant_node and axes_values is not None:
                    # Update node: remove second input, add axes attr
                    node.inputs = [data_input]
                    node.attrs['axes'] = axes_values
                    # Ensure keepdims is set (default 1 for most reductions)
                    if 'keepdims' not in node.attrs:
                        node.attrs['keepdims'] = 1
                    fixed_count += 1
                    print(f"[{graph_name}] Fixed {node.op} node '{node.name}': axes {axes_values} extracted from Constant '{constant_node.name}'")
                    # Mark for removal; cleanup will handle unused nodes
                    removed_constants += 1
                else:
                    print(f"[{graph_name}] Warning: Could not find/extract axes for {node.op} '{node.name}'; second input '{axes_var.name}', Constant found: {constant_node is not None}")
    if debug_nodes:
        print(f"[{graph_name}] Total {', '.join(set(dn['op'] for dn in debug_nodes))} nodes: {len(debug_nodes)}, Fixed: {fixed_count}")
        for dn in debug_nodes[:3]:  # Print first 3 for brevity
            print(f"  - {dn['name']}: {dn['op']}, {dn['inputs_count']} inputs, types: {dn['inputs_types']}, second_name: {dn['second_input_name']}")
        if len(debug_nodes) > 3:
            print(f"  ... and {len(debug_nodes)-3} more")
    return fixed_count


def _rename_graph_tensors_and_nodes(graph: gs.Graph, prefix: str, skip_vars: List[gs.Variable] = None):
    """Prefix all tensor and node names in `graph` with `prefix`, except variables in skip_vars.

    This avoids name collisions when combining multiple graphs. We compare skip_vars by object id to
    ensure we don't rename the shared input Variable object.
    """
    if skip_vars is None:
        skip_vars = []
    skip_ids = {id(v) for v in skip_vars}

    # Rename variables (tensors)
    tensors = list(graph.tensors().values())
    for var in tensors:
        if id(var) in skip_ids:
            continue
        if var.name:
            var.name = prefix + var.name

    # Rename nodes
    for node in graph.nodes:
        if node.name:
            node.name = prefix + node.name


def create_combined_onnx(model_path1, model_path2, output_path='combined.onnx'):
    """
    Combines two ONNX models into one:
    - Model1: takes 'image' and 'demographics' -> logits1
    - Model2: takes 'image' -> logits2
    - Combined: takes 'image' and 'demographics' -> (softmax(logits1) + logits2) / 2

    Key changes vs. earlier: we rename the second graph's tensors/nodes with a prefix to avoid name collisions
    and ensure the shared `image` input variable object is used by both graphs. This prevents duplicate tensor
    names and topological ordering issues during checker validation.
    """
    # Load the models
    onnx_model1 = onnx.load(model_path1)
    onnx_model2 = onnx.load(model_path2)

    # Import into graph surgeon
    graph1 = gs.import_onnx(onnx_model1)
    graph2 = gs.import_onnx(onnx_model2)

    # Fix reduction nodes in BOTH graphs for thoroughness
    total_fixed = 0
    total_fixed += fix_reduction_nodes(graph1, "Model1")
    total_fixed += fix_reduction_nodes(graph2, "Model2")
    if total_fixed == 0:
        print("No reduction fixes applied - check debug output above")

    # Rename for clarity and sharing
    image_input = graph1.inputs[0]
    image_input.name = 'image'

    demographics_input = graph1.inputs[1]
    demographics_input.name = 'demographics'

    # Grab model2's image input object BEFORE renaming so we can skip renaming that specific Variable
    old_image_input = graph2.inputs[0]

    # Rename graph2 tensors/nodes to avoid clashes (but don't rename the image Variable object)
    _rename_graph_tensors_and_nodes(graph2, prefix='g2_', skip_vars=[old_image_input])

    # Replace all references in graph2 nodes from old_image_input to the shared image_input object
    for node in graph2.nodes:
        for i in range(len(node.inputs)):
            if node.inputs[i] is old_image_input:
                node.inputs[i] = image_input

    # Update graph2's inputs list to use the shared input object (this removes a duplicate input with same name)
    graph2.inputs[0] = image_input

    # Get outputs (assume single output each)
    logits1 = graph1.outputs[0]
    logits1.name = 'logits1'

    logits2 = graph2.outputs[0]
    logits2.name = 'logits2'

    # Extract num_classes from logits1 shape (assume [batch, num_classes]; batch dynamic)
    orig_shape = logits1.shape
    if orig_shape and len(orig_shape) >= 2:
        num_classes = orig_shape[-1]
        if num_classes == 0 or num_classes is None:
            num_classes = 10  # Fallback assumption based on reported output size
        output_shape = [None, num_classes]  # Dynamic batch
    else:
        output_shape = [None, 10]  # Fallback
        num_classes = 10
        print(f"Warning: Could not infer num_classes from shape {orig_shape}; using fallback [None, 10]")

    print(f"Inferred output shape: {output_shape}")

    # Define output variables WITH dtype and shape (no flattening)
    probs1 = gs.Variable('probs1', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    probs2 = gs.Variable('probs2', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    sum_avg = gs.Variable('sum_avg', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    avg_output1 = gs.Variable('avg_output1', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    avg_output2 = gs.Variable('avg_output2', shape=output_shape, dtype=onnx.TensorProto.FLOAT)

    # Softmax on first model (axis=1 for [batch, classes])
    softmax1 = gs.Node(
        op='Softmax',
        inputs=[logits1],
        outputs=[probs1],
        attrs={'axis': 1}
    )    
    softmax2 = gs.Node(
        op='Softmax',
        inputs=[logits2],
        outputs=[probs2],
        attrs={'axis': 1}
    )

    # Average: (probs1 + logits2) / 2

    constant_07 = gs.Constant(name='constant_07', values=np.array(0.7, dtype=np.float32))  # Scalar for broadcast
    constant_03 = gs.Constant(name='constant_03', values=np.array(0.3, dtype=np.float32))  # Scalar for broadcast
    mul1 = gs.Node(
        op='Mul',  # Equivalent to / 2
        inputs=[probs1, constant_03],
        outputs=[avg_output1]
    )
    mul2 = gs.Node(
        op='Mul',  # Equivalent to / 2
        inputs=[probs2, constant_07],
        outputs=[avg_output2]
    )
    add = gs.Node(
        op='Add',
        inputs=[avg_output1, avg_output2],
        outputs=[sum_avg]
    )

    # Combined graph: nodes from both + new nodes; inputs: image + demographics; output: avg_output
    # We put graph1 nodes first, then graph2 nodes (which we've namespaced) so producers appear before consumers.
    combined_graph = gs.Graph(
        nodes=graph1.nodes + graph2.nodes + [softmax1, softmax2, mul1, mul2, add],
        inputs=[image_input, demographics_input],
        outputs=[sum_avg]
    )

    # Set opset on the graph for LayerNormalization support (opset 17+)
    combined_graph.opset = 17

    # Cleanup and export - cleanup will remove unused nodes and should also fix ordering where possible
    combined_model = gs.export_onnx(combined_graph.cleanup())

    # Infer shapes to fill in any missing (helps checker)
    combined_model = shape_inference.infer_shapes(combined_model)

    # Optional: Check model
    onnx.checker.check_model(combined_model)

    # Save
    onnx.save(combined_model, output_path)
    print(f"Combined ONNX model saved to {output_path}")
    print(f"Output shape: {output_shape}")

    return combined_model


# Usage
# Note: adjust paths as needed
combined = create_combined_onnx('model/36.onnx', 'model/84.onnx', "model/softmax_36_84(3.7).onnx")


### Combining 36 vs 108

In [None]:
import onnx
import onnx_graphsurgeon as gs
import numpy as np
import uuid
from onnx import shape_inference
from typing import List


def fix_reduction_nodes(graph: gs.Graph, graph_name: str = "unknown"):
    """
    Fixes ReduceL2 and ReduceMean nodes that incorrectly have axes as input (2 inputs) by moving axes to attribute.
    Searches for the Constant node producing the axes Variable and extracts its value.
    Removes the axes input and the unused Constant node after fix.
    Adds debug prints for all ReduceL2 and ReduceMean nodes.
    """
    fixed_count = 0
    removed_constants = 0
    debug_nodes = []
    for node in graph.nodes:
        if node.op in ['ReduceL2', 'ReduceMean']:
            debug_nodes.append({
                'name': node.name,
                'op': node.op,
                'inputs_count': len(node.inputs),
                'inputs_types': [type(inp).__name__ for inp in node.inputs],
                'second_input_name': node.inputs[1].name if len(node.inputs) > 1 else None
            })
            if len(node.inputs) == 2:
                data_input = node.inputs[0]
                axes_var = node.inputs[1]
                # Search for Constant node producing axes_var
                constant_node = None
                axes_values = None
                for c_node in graph.nodes:
                    if (c_node.op == 'Constant' and 
                        c_node.outputs and len(c_node.outputs) == 1 and 
                        c_node.outputs[0].name == axes_var.name):
                        constant_node = c_node
                        if 'value' in c_node.attrs:
                            axes_values = c_node.attrs['value'].values
                            if isinstance(axes_values, np.ndarray):
                                axes_values = axes_values.tolist()
                        break
                if constant_node and axes_values is not None:
                    # Update node: remove second input, add axes attr
                    node.inputs = [data_input]
                    node.attrs['axes'] = axes_values
                    # Ensure keepdims is set (default 1 for most reductions)
                    if 'keepdims' not in node.attrs:
                        node.attrs['keepdims'] = 1
                    fixed_count += 1
                    print(f"[{graph_name}] Fixed {node.op} node '{node.name}': axes {axes_values} extracted from Constant '{constant_node.name}'")
                    # Mark for removal; cleanup will handle unused nodes
                    removed_constants += 1
                else:
                    print(f"[{graph_name}] Warning: Could not find/extract axes for {node.op} '{node.name}'; second input '{axes_var.name}', Constant found: {constant_node is not None}")
    if debug_nodes:
        print(f"[{graph_name}] Total {', '.join(set(dn['op'] for dn in debug_nodes))} nodes: {len(debug_nodes)}, Fixed: {fixed_count}")
        for dn in debug_nodes[:3]:  # Print first 3 for brevity
            print(f"  - {dn['name']}: {dn['op']}, {dn['inputs_count']} inputs, types: {dn['inputs_types']}, second_name: {dn['second_input_name']}")
        if len(debug_nodes) > 3:
            print(f"  ... and {len(debug_nodes)-3} more")
    return fixed_count


def _rename_graph_tensors_and_nodes(graph: gs.Graph, prefix: str, skip_vars: List[gs.Variable] = None):
    """Prefix all tensor and node names in `graph` with `prefix`, except variables in skip_vars.

    This avoids name collisions when combining multiple graphs. We compare skip_vars by object id to
    ensure we don't rename shared input Variable objects.
    """
    if skip_vars is None:
        skip_vars = []
    skip_ids = {id(v) for v in skip_vars}

    # Rename variables (tensors)
    for var in list(graph.tensors().values()):
        # skip renaming the exact variable objects that are shared
        if id(var) in skip_ids:
            continue
        if var.name:
            var.name = prefix + var.name

    # Rename nodes
    for node in graph.nodes:
        if node.name:
            node.name = prefix + node.name


def create_combined_onnx(model_path1, model_path2, output_path='combined.onnx'):
    """
    Combines two ONNX models into one:
    - Model1: takes 'image' and 'demographics' -> logits1
    - Model2: takes 'image' -> logits2
    - Combined: takes 'image' and 'demographics' -> average(softmax(logits1), softmax(logits2))

    Approach:
    - Import both graphs with onnx-graphsurgeon
    - Capture model2's original input Variable objects (so we can find & replace them)
    - Namespace (prefix) all graph2 tensors/nodes to avoid collisions, except the original model2 input Variable objects
    - Replace model2's input references with the shared input Variable objects from graph1
    - Build combined graph with graph1 nodes first, then graph2 nodes, then the new ops
    - Cleanup, infer shapes, and run checker
    """
    # Load the models
    onnx_model1 = onnx.load(model_path1)
    onnx_model2 = onnx.load(model_path2)

    # Import into graph surgeon
    graph1 = gs.import_onnx(onnx_model1)
    graph2 = gs.import_onnx(onnx_model2)

    # Fix reduction nodes in BOTH graphs for thoroughness
    total_fixed = 0
    total_fixed += fix_reduction_nodes(graph1, "Model1")
    total_fixed += fix_reduction_nodes(graph2, "Model2")
    if total_fixed == 0:
        print("No reduction fixes applied - check debug output above")

    # Prepare shared inputs from graph1
    image_input = graph1.inputs[0]
    image_input.name = 'image'

    demographics_input = graph1.inputs[1]
    demographics_input.name = 'demographics'

    # Save model2's original input variable objects so we can target them for replacement
    model2_image_var = graph2.inputs[0]
    model2_demo_var = None
    if len(graph2.inputs) > 1:
        model2_demo_var = graph2.inputs[1]

    # Namespace graph2 to avoid collisions but skip the original input objects
    prefix = 'g2_'
    _rename_graph_tensors_and_nodes(graph2, prefix=prefix, skip_vars=[model2_image_var] + ([model2_demo_var] if model2_demo_var is not None else []))

    # Replace references in graph2 nodes from the original model2 input objects to the shared ones
    for node in graph2.nodes:
        for i, inp in enumerate(node.inputs):
            if inp is model2_image_var:
                node.inputs[i] = image_input
            elif model2_demo_var is not None and inp is model2_demo_var:
                node.inputs[i] = demographics_input

    # Now update graph2's inputs list to use the shared input objects
    graph2.inputs[0] = image_input
    if model2_demo_var is not None:
        # If model2 had a demographics input, map it to the shared demographics
        graph2.inputs[1] = demographics_input

    # Get outputs (assume single output each)
    logits1 = graph1.outputs[0]
    logits1.name = 'logits1'

    logits2 = graph2.outputs[0]
    # The logits2 variable object may have been renamed (prefixed), ensure we use the object itself
    logits2.name = 'logits2'

    # Extract num_classes from logits1 shape (assume [batch, num_classes]; batch dynamic)
    orig_shape = logits1.shape
    if orig_shape and len(orig_shape) >= 2:
        num_classes = orig_shape[-1]
        if num_classes == 0 or num_classes is None:
            num_classes = 10  # Fallback assumption based on reported output size
        output_shape = [None, num_classes]  # Dynamic batch
    else:
        output_shape = [None, 10]  # Fallback
        num_classes = 10
        print(f"Warning: Could not infer num_classes from shape {orig_shape}; using fallback [None, 10]")

    print(f"Inferred output shape: {output_shape}")

    # Create variables for intermediate and final tensors
    probs1 = gs.Variable('probs1', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    probs2 = gs.Variable('probs2', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    avg1 = gs.Variable('avg1', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    avg2 = gs.Variable('avg2', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    final_avg = gs.Variable('final_avg', shape=output_shape, dtype=onnx.TensorProto.FLOAT)

    # Softmax on both logits (axis=1)
    softmax1 = gs.Node(op='Softmax', inputs=[logits1], outputs=[probs1], attrs={'axis': 1})
    softmax2 = gs.Node(op='Softmax', inputs=[logits2], outputs=[probs2], attrs={'axis': 1})

    # Multiply each probability vector by 0.5 (use uniquely named constant to avoid duplication)
    constant_07 = gs.Constant(name='constant_07', values=np.array(0.7, dtype=np.float32))  # Scalar for broadcast
    constant_03 = gs.Constant(name='constant_03', values=np.array(0.3, dtype=np.float32))  # Scalar for broadcast

    mul1 = gs.Node(op='Mul', inputs=[probs1, constant_07], outputs=[avg1])
    mul2 = gs.Node(op='Mul', inputs=[probs2, constant_03], outputs=[avg2])

    # Add the two halves to get the average
    add = gs.Node(op='Add', inputs=[avg1, avg2], outputs=[final_avg])

    # Build combined graph
    # Place graph1 nodes first, then graph2 nodes (namespaced), then our fusion nodes
    combined_nodes = list(graph1.nodes) + list(graph2.nodes) + [softmax1, softmax2, mul1, mul2, add]

    combined_graph = gs.Graph(nodes=combined_nodes, inputs=[image_input, demographics_input], outputs=[final_avg])

    # Set a reasonable opset
    combined_graph.opset = max(getattr(graph1, 'opset', 11), getattr(graph2, 'opset', 11), 11)

    # Cleanup and export
    combined_model = gs.export_onnx(combined_graph.cleanup())

    # Infer shapes
    combined_model = shape_inference.infer_shapes(combined_model)

    # Validate
    onnx.checker.check_model(combined_model)

    # Save
    onnx.save(combined_model, output_path)
    print(f"Combined ONNX model saved to {output_path}")
    print(f"Output shape: {output_shape}")

    return combined_model


# Usage example (adjust paths as needed)
combined = create_combined_onnx('model/36.onnx', 'model/108_grose.onnx', "model/softmax_36_108(7.3).onnx")


### Combining 18 vs 62

#### Strategy 1

In [56]:
import onnx
import onnx_graphsurgeon as gs
import numpy as np
from onnx import shape_inference
from typing import List


def fix_reduction_nodes(graph: gs.Graph, graph_name: str = "unknown"):
    """
    Fixes ReduceL2 and ReduceMean nodes that incorrectly have axes as input (2 inputs) by moving axes to attribute.
    Searches for the Constant node producing the axes Variable and extracts its value.
    Removes the axes input and the unused Constant node after fix.
    Adds debug prints for all ReduceL2 and ReduceMean nodes.
    """
    fixed_count = 0
    removed_constants = 0
    debug_nodes = []
    for node in graph.nodes:
        if node.op in ["ReduceL2", "ReduceMean"]:
            debug_nodes.append(
                {
                    "name": node.name,
                    "op": node.op,
                    "inputs_count": len(node.inputs),
                    "inputs_types": [type(inp).__name__ for inp in node.inputs],
                    "second_input_name": (
                        node.inputs[1].name if len(node.inputs) > 1 else None
                    ),
                }
            )
            if len(node.inputs) == 2:
                data_input = node.inputs[0]
                axes_var = node.inputs[1]
                # Search for Constant node producing axes_var
                constant_node = None
                axes_values = None
                for c_node in graph.nodes:
                    if (
                        c_node.op == "Constant"
                        and c_node.outputs
                        and len(c_node.outputs) == 1
                        and c_node.outputs[0].name == axes_var.name
                    ):
                        constant_node = c_node
                        if "value" in c_node.attrs:
                            axes_values = c_node.attrs["value"].values
                            if isinstance(axes_values, np.ndarray):
                                axes_values = axes_values.tolist()
                        break
                if constant_node and axes_values is not None:
                    # Update node: remove second input, add axes attr
                    node.inputs = [data_input]
                    node.attrs["axes"] = axes_values
                    # Ensure keepdims is set (default 1 for most reductions)
                    if "keepdims" not in node.attrs:
                        node.attrs["keepdims"] = 1
                    fixed_count += 1
                    print(
                        f"[{graph_name}] Fixed {node.op} node '{node.name}': axes {axes_values} extracted from Constant '{constant_node.name}'"
                    )
                    # Mark for removal; cleanup will handle unused nodes
                    removed_constants += 1
                else:
                    print(
                        f"[{graph_name}] Warning: Could not find/extract axes for {node.op} '{node.name}'; second input '{axes_var.name}', Constant found: {constant_node is not None}"
                    )
    if debug_nodes:
        print(
            f"[{graph_name}] Total {', '.join(set(dn['op'] for dn in debug_nodes))} nodes: {len(debug_nodes)}, Fixed: {fixed_count}"
        )
        for dn in debug_nodes[:3]:  # Print first 3 for brevity
            print(
                f"  - {dn['name']}: {dn['op']}, {dn['inputs_count']} inputs, types: {dn['inputs_types']}, second_name: {dn['second_input_name']}"
            )
        if len(debug_nodes) > 3:
            print(f"  ... and {len(debug_nodes)-3} more")
    return fixed_count


def _rename_graph_tensors_and_nodes(
    graph: gs.Graph, prefix: str, skip_vars: List[gs.Variable] = None
):
    """Prefix all tensor and node names in `graph` with `prefix`, except variables in skip_vars.

    This avoids name collisions when combining multiple graphs. We compare skip_vars by object id to
    ensure we don't rename the shared input Variable object.
    """
    if skip_vars is None:
        skip_vars = []
    skip_ids = {id(v) for v in skip_vars}

    # Rename variables (tensors)
    tensors = list(graph.tensors().values())
    for var in tensors:
        if id(var) in skip_ids:
            continue
        if var.name:
            var.name = prefix + var.name

    # Rename nodes
    for node in graph.nodes:
        if node.name:
            node.name = prefix + node.name


def create_combined_onnx(
    model_path1,
    model_path2,
    whitelist1: List[int],
    whitelist2: List[int],
    scale_const: float = 3.0,
    output_path="combined.onnx",
):
    """
    Combines two ONNX models into one with conditional logic based on whitelists for skin cancer strategy:
    - Compute logits1 from model1.
    - If argmax(logits1) in whitelist1, use logits1.
    - Else compute logits2 from model2.
    - If argmax(logits2) in whitelist2, use logits2.
    - Else use average of logits1 and logits2.
    - Then scale the chosen logits by scale_const and apply softmax for output.

    Assumes:
    - Both models take 'image' and 'demographics' as inputs.
    - Both output a single tensor of shape [batch_size, num_classes] (logits).
    - Whitelists are lists of class indices (integers).
    - You need to pip install onnx onnx-graphsurgeon if not already installed.
    """
    # Load the models
    onnx_model1 = onnx.load(model_path1)
    onnx_model2 = onnx.load(model_path2)

    # Import into graph surgeon
    graph1 = gs.import_onnx(onnx_model1)
    graph2 = gs.import_onnx(onnx_model2)

    # Fix reduction nodes in BOTH graphs for thoroughness
    total_fixed = 0
    total_fixed += fix_reduction_nodes(graph1, "Model1")
    total_fixed += fix_reduction_nodes(graph2, "Model2")
    if total_fixed == 0:
        print("No reduction fixes applied - check debug output above")

    # Assume input names and order based on your code
    # Rename for clarity
    image_input = graph1.inputs[0]
    image_input.name = "image"

    demographics_input = graph1.inputs[1]
    demographics_input.name = "demographics"

    # Share the inputs with model2
    old_image_input = graph2.inputs[0]
    old_demo_input = graph2.inputs[1]

    _rename_graph_tensors_and_nodes(
        graph2, prefix="g2_", skip_vars=[old_image_input, old_demo_input]
    )

    # Replace all references in graph2 nodes to use shared inputs
    for node in graph2.nodes:
        for i in range(len(node.inputs)):
            if node.inputs[i] is old_image_input:
                node.inputs[i] = image_input
            if node.inputs[i] is old_demo_input:
                node.inputs[i] = demographics_input

    # Update graph2's inputs list to use the shared inputs
    graph2.inputs[0] = image_input
    graph2.inputs[1] = demographics_input

    # Get outputs (assume single output each)
    logits1 = graph1.outputs[0]
    logits1.name = "logits1"

    logits2 = graph2.outputs[0]
    logits2.name = "logits2"

    # Extract num_classes from logits1 shape (assume [batch, num_classes]; batch dynamic)
    orig_shape = logits1.shape
    if orig_shape and len(orig_shape) >= 2:
        num_classes = orig_shape[-1]
        if num_classes == 0 or num_classes is None:
            num_classes = 11  # Fallback assumption
        output_shape = [None, num_classes]  # Dynamic batch
    else:
        output_shape = [None, 11]  # Fallback
        num_classes = 11
        print(
            f"Warning: Could not infer num_classes from shape {orig_shape}; using fallback [None, 11]"
        )

    print(f"Inferred output shape: {output_shape}")

    # Define variables
    class1 = gs.Variable("class1", shape=[None], dtype=np.int64)
    class1_unsq = gs.Variable("class1_unsq", shape=[None, 1], dtype=np.int64)
    eq1 = gs.Variable("eq1", shape=[None, len(whitelist1)], dtype=np.bool)
    cast1 = gs.Variable("cast1", shape=[None, len(whitelist1)], dtype=np.float32)
    reduce1 = gs.Variable("reduce1", shape=[None, 1], dtype=np.float32)
    is_in1 = gs.Variable("is_in1", shape=[None, 1], dtype=np.bool)
    is_in1_exp = gs.Variable("is_in1_exp", shape=output_shape, dtype=np.bool)

    class2 = gs.Variable("class2", shape=[None], dtype=np.int64)
    class2_unsq = gs.Variable("class2_unsq", shape=[None, 1], dtype=np.int64)
    eq2 = gs.Variable("eq2", shape=[None, len(whitelist2)], dtype=np.bool)
    cast2 = gs.Variable("cast2", shape=[None, len(whitelist2)], dtype=np.float32)
    reduce2 = gs.Variable("reduce2", shape=[None, 1], dtype=np.float32)
    is_in2 = gs.Variable("is_in2", shape=[None, 1], dtype=np.bool)
    is_in2_exp = gs.Variable("is_in2_exp", shape=output_shape, dtype=np.bool)

    logits_shape_var = gs.Variable("logits_shape", dtype=np.int64, shape=[2])

    add_output = gs.Variable("add_output", shape=output_shape, dtype=np.float32)
    avg_logits = gs.Variable("avg_logits", shape=output_shape, dtype=np.float32)
    inner_selected = gs.Variable("inner_selected", shape=output_shape, dtype=np.float32)
    selected_logits = gs.Variable(
        "selected_logits", shape=output_shape, dtype=np.float32
    )
    scaled_logits = gs.Variable("scaled_logits", shape=output_shape, dtype=np.float32)
    final_output = gs.Variable("final_output", shape=output_shape, dtype=np.float32)

    # Constants
    whitelist1_const = gs.Constant(
        "whitelist1", values=np.array(whitelist1, dtype=np.int64)
    )
    whitelist2_const = gs.Constant(
        "whitelist2", values=np.array(whitelist2, dtype=np.int64)
    )
    zero_const = gs.Constant("zero", values=np.array(0.0, dtype=np.float32))
    two_const = gs.Constant("two", values=np.array(2.0, dtype=np.float32))
    scale_const_node = gs.Constant(
        "scale", values=np.array(scale_const, dtype=np.float32)
    )
    axes_unsq = gs.Constant("axes_unsq", values=np.array([1], dtype=np.int64))
    axes_reduce = gs.Constant("axes_reduce", values=np.array([1], dtype=np.int64))

    # Nodes for whitelist1 check
    argmax1 = gs.Node(
        op="ArgMax", inputs=[logits1], outputs=[class1], attrs={"axis": 1, "keepdims": 0}
    )
    unsqueeze1 = gs.Node(
        op="Unsqueeze", inputs=[class1, axes_unsq], outputs=[class1_unsq]
    )
    equal1 = gs.Node(op="Equal", inputs=[class1_unsq, whitelist1_const], outputs=[eq1])
    cast1_node = gs.Node(
        op="Cast", inputs=[eq1], outputs=[cast1], attrs={"to": onnx.TensorProto.FLOAT}
    )
    reducesum1 = gs.Node(
        op="ReduceSum",
        inputs=[cast1, axes_reduce],
        outputs=[reduce1],
        attrs={"keepdims": 1},
    )
    greater1 = gs.Node(op="Greater", inputs=[reduce1, zero_const], outputs=[is_in1])

    # Nodes for whitelist2 check
    argmax2 = gs.Node(
        op="ArgMax", inputs=[logits2], outputs=[class2], attrs={"axis": 1, "keepdims": 0}
    )
    unsqueeze2 = gs.Node(
        op="Unsqueeze", inputs=[class2, axes_unsq], outputs=[class2_unsq]
    )
    equal2 = gs.Node(op="Equal", inputs=[class2_unsq, whitelist2_const], outputs=[eq2])
    cast2_node = gs.Node(
        op="Cast", inputs=[eq2], outputs=[cast2], attrs={"to": onnx.TensorProto.FLOAT}
    )
    reducesum2 = gs.Node(
        op="ReduceSum",
        inputs=[cast2, axes_reduce],
        outputs=[reduce2],
        attrs={"keepdims": 1},
    )
    greater2 = gs.Node(op="Greater", inputs=[reduce2, zero_const], outputs=[is_in2])

    # Dynamic shape for expand
    shape_node = gs.Node(op="Shape", inputs=[logits1], outputs=[logits_shape_var])

    expand1 = gs.Node(
        op="Expand", inputs=[is_in1, logits_shape_var], outputs=[is_in1_exp]
    )
    expand2 = gs.Node(
        op="Expand", inputs=[is_in2, logits_shape_var], outputs=[is_in2_exp]
    )

    # Combine logic
    add_logits = gs.Node(op="Add", inputs=[logits1, logits2], outputs=[add_output])
    avg_node = gs.Node(op="Div", inputs=[add_output, two_const], outputs=[avg_logits])
    inner_where = gs.Node(
        op="Where",
        inputs=[is_in2_exp, logits2, avg_logits],
        outputs=[inner_selected],
    )
    outer_where = gs.Node(
        op="Where",
        inputs=[is_in1_exp, logits1, inner_selected],
        outputs=[selected_logits],
    )
    scale_mul = gs.Node(
        op="Mul", inputs=[selected_logits, scale_const_node], outputs=[scaled_logits]
    )
    softmax_final = gs.Node(
        op="Softmax", inputs=[scaled_logits], outputs=[final_output], attrs={"axis": 1}
    )

    # Combined graph: nodes from both + new nodes; inputs: image + demographics; output: final_output
    combined_graph = gs.Graph(
        nodes=graph1.nodes
        + graph2.nodes
        + [
            argmax1,
            unsqueeze1,
            equal1,
            cast1_node,
            reducesum1,
            greater1,
            argmax2,
            unsqueeze2,
            equal2,
            cast2_node,
            reducesum2,
            greater2,
            shape_node,
            expand1,
            expand2,
            add_logits,
            avg_node,
            inner_where,
            outer_where,
            scale_mul,
            softmax_final,
        ],
        inputs=[image_input, demographics_input],
        outputs=[final_output],
    )

    # Set opset on the graph for LayerNormalization support (opset 17+)
    combined_graph.opset = 17

    # Cleanup and export
    combined_model = gs.export_onnx(combined_graph.cleanup())

    # Infer shapes to fill in any missing (helps checker)
    combined_model = shape_inference.infer_shapes(combined_model)

    # Optional: Check model
    onnx.checker.check_model(combined_model)

    # Save
    onnx.save(combined_model, output_path)
    print(f"Combined ONNX model saved to {output_path}")
    print(f"Output shape: {output_shape}")

    return combined_model


# Usage example (replace with your actual whitelists)
whitelist1=[1, 2, 4, 8, 9]
whitelist2=[0, 3, 10, 5, 6, 7]
combined = create_combined_onnx(
    "../../models/2025-11-27/speechmaster/62_model94.onnx",
    "../../models/2025-11-27/speechmaster/18_model118.onnx",
    whitelist1,
    whitelist2,
    3.0,
    "../../models/combine/2025-11-27/62vs18.onnx",
)

[Model1] Total ReduceMean nodes: 32, Fixed: 0
  - /model/base_model/blocks/blocks.0/blocks.0.0/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  - /model/base_model/blocks/blocks.0/blocks.0.1/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  - /model/base_model/blocks/blocks.1/blocks.1.0/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  ... and 29 more
[Model2] Total ReduceMean nodes: 34, Fixed: 0
  - /model/base_model/blocks/blocks.0/blocks.0.0/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  - /model/base_model/blocks/blocks.0/blocks.0.1/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  - /model/base_model/blocks/blocks.1/blocks.1.0/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  ... and 31 more
No reduction fixes applied - check debug output above
Inferred output shape: [None, 11]
Combined ONNX model saved to ../..

#### Strategy 2

In [57]:
import onnx
import onnx_graphsurgeon as gs
import numpy as np
from onnx import shape_inference
from typing import List, Tuple


def _fix_reduction_nodes(graph: gs.Graph) -> None:
    for node in graph.nodes:
        if node.op in {"ReduceMean", "ReduceL2"} and len(node.inputs) == 2:
            data, axes = node.inputs
            const = next(
                (n for n in graph.nodes if n.op == "Constant" and n.outputs[0] is axes),
                None,
            )
            if const and hasattr(const.attrs["value"], "values"):
                node.inputs = [data]
                node.attrs["axes"] = const.attrs["value"].values.tolist()
                node.attrs["keepdims"] = 1


def _share_inputs_and_rename(g1: gs.Graph, g2: gs.Graph, prefix: str = "m2_"):
    img = g1.inputs[0]
    demo = g1.inputs[1]
    img.name = "image"
    demo.name = "demographics"

    old_img, old_demo = g2.inputs[0], g2.inputs[1]
    skip = {id(old_img), id(old_demo)}
    for v in g2.tensors().values():
        if id(v) not in skip and v.name:
            v.name = prefix + v.name
    for n in g2.nodes:
        if n.name:
            n.name = prefix + n.name

    for n in g2.nodes:
        for i, inp in enumerate(n.inputs):
            if inp is old_img:
                n.inputs[i] = img
            if inp is old_demo:
                n.inputs[i] = demo
    g2.inputs = [img, demo]
    return img, demo


def create_combined_onnx(
    model_path1: str,
    model_path2: str,
    whitelist1: List[int],
    whitelist2: List[int],
    scale_factor: float = 3.0,
    output_path: str = "combined.onnx",
) -> onnx.ModelProto:
    # Load & prepare graphs
    g1 = gs.import_onnx(onnx.load(model_path1))
    g2 = gs.import_onnx(onnx.load(model_path2))
    _fix_reduction_nodes(g1)
    _fix_reduction_nodes(g2)
    img_input, demo_input = _share_inputs_and_rename(g1, g2, "m2_")

    logits1 = g1.outputs[0]
    logits2 = g2.outputs[0]
    logits1.name = "logits_m1"
    logits2.name = "logits_m2"

    num_classes = logits1.shape[1] if logits1.shape and len(logits1.shape) == 2 else 11
    batch_shape = (None, num_classes)

    # Constants
    w1 = gs.Constant("w1", np.array(whitelist1, np.int64))
    w2 = gs.Constant("w2", np.array(whitelist2, np.int64))
    zero = gs.Constant("zero", np.array(0.0, np.float32))
    two = gs.Constant("two", np.array(2.0, np.float32))
    scale = gs.Constant("scale", np.array(scale_factor, np.float32))
    axis1 = gs.Constant("axis1", np.array([1], np.int64))

    # Helper: membership test that works with ReduceSum
    def membership_nodes(
        class_unsq: gs.Variable, whitelist_const: gs.Constant, prefix: str
    ):
        eq = gs.Variable(f"{prefix}_eq", dtype=np.bool)
        cast = gs.Variable(f"{prefix}_cast", dtype=np.float32)
        reduced = gs.Variable(f"{prefix}_reduced", dtype=np.float32, shape=["batch", 1])
        is_member = gs.Variable(f"{prefix}_member", dtype=np.bool, shape=["batch", 1])

        nodes = [
            gs.Node(
                op="Equal",
                name=f"eq_{prefix}",
                inputs=[class_unsq, whitelist_const],
                outputs=[eq],
            ),
            gs.Node(
                op="Cast",
                name=f"cast_{prefix}",
                inputs=[eq],
                outputs=[cast],
                attrs={"to": onnx.TensorProto.FLOAT},
            ),
            gs.Node(
                op="ReduceSum",
                name=f"red_{prefix}",
                inputs=[cast, axis1],
                outputs=[reduced],
                attrs={"keepdims": 1},
            ),
            gs.Node(
                op="Greater",
                name=f"gt_{prefix}",
                inputs=[reduced, zero],
                outputs=[is_member],
            ),
        ]
        return nodes, is_member

    # ArgMax + Unsqueeze
    c1 = gs.Variable("c1", np.int64, ["batch"])
    c2 = gs.Variable("c2", np.int64, ["batch"])
    c1_u = gs.Variable("c1_u", np.int64, ["batch", 1])
    c2_u = gs.Variable("c2_u", np.int64, ["batch", 1])

    nodes = [
        gs.Node(
            op="ArgMax",
            name="argmax1",
            inputs=[logits1],
            outputs=[c1],
            attrs={"axis": 1, "keepdims": 0},
        ),
        gs.Node(
            op="ArgMax",
            name="argmax2",
            inputs=[logits2],
            outputs=[c2],
            attrs={"axis": 1, "keepdims": 0},
        ),
        gs.Node(op="Unsqueeze", name="unsq1", inputs=[c1, axis1], outputs=[c1_u]),
        gs.Node(op="Unsqueeze", name="unsq2", inputs=[c2, axis1], outputs=[c2_u]),
    ]

    # Membership checks
    n1, m1_in_w1 = membership_nodes(c1_u, w1, "m1_w1")
    n2, m1_in_w2 = membership_nodes(c1_u, w2, "m1_w2")
    n3, m2_in_w1 = membership_nodes(c2_u, w1, "m2_w1")
    n4, m2_in_w2 = membership_nodes(c2_u, w2, "m2_w2")
    nodes += n1 + n2 + n3 + n4

    # Conditions
    m1_not_w2 = gs.Variable("m1_not_w2", np.bool, ["batch", 1])
    m2_not_w1 = gs.Variable("m2_not_w1", np.bool, ["batch", 1])
    use_m1 = gs.Variable("use_model1", np.bool, ["batch", 1])
    use_m2 = gs.Variable("use_model2", np.bool, ["batch", 1])

    nodes += [
        gs.Node(op="Not", name="not1", inputs=[m1_in_w2], outputs=[m1_not_w2]),
        gs.Node(op="Not", name="not2", inputs=[m2_in_w1], outputs=[m2_not_w1]),
        gs.Node(
            op="And", name="and_m1", inputs=[m1_in_w1, m1_not_w2], outputs=[use_m1]
        ),
        gs.Node(
            op="And", name="and_m2", inputs=[m2_not_w1, m2_in_w2], outputs=[use_m2]
        ),
    ]

    # Expand masks
    shape_var = gs.Variable("shape", np.int64, [2])
    use_m1_exp = gs.Variable("use_m1_exp", np.bool, batch_shape)
    use_m2_exp = gs.Variable("use_m2_exp", np.bool, batch_shape)

    nodes += [
        gs.Node(op="Shape", name="shape", inputs=[logits1], outputs=[shape_var]),
        gs.Node(
            op="Expand", name="exp_m1", inputs=[use_m1, shape_var], outputs=[use_m1_exp]
        ),
        gs.Node(
            op="Expand", name="exp_m2", inputs=[use_m2, shape_var], outputs=[use_m2_exp]
        ),
    ]

    # Average fallback
    sum_ab = gs.Variable("sum", np.float32, batch_shape)
    avg = gs.Variable("avg", np.float32, batch_shape)
    temp = gs.Variable("temp", np.float32, batch_shape)
    selected = gs.Variable("selected", np.float32, batch_shape)
    scaled = gs.Variable("scaled", np.float32, batch_shape)
    probs = gs.Variable("probabilities", np.float32, batch_shape)

    nodes += [
        gs.Node(op="Add", name="add", inputs=[logits1, logits2], outputs=[sum_ab]),
        gs.Node(op="Div", name="div", inputs=[sum_ab, two], outputs=[avg]),
        gs.Node(
            op="Where",
            name="where_m2",
            inputs=[use_m2_exp, logits2, avg],
            outputs=[temp],
        ),
        gs.Node(
            op="Where",
            name="where_final",
            inputs=[use_m1_exp, logits1, temp],
            outputs=[selected],
        ),
        gs.Node(op="Mul", name="mul", inputs=[selected, scale], outputs=[scaled]),
        gs.Node(
            op="Softmax",
            name="softmax",
            inputs=[scaled],
            outputs=[probs],
            attrs={"axis": 1},
        ),
    ]

    # Final graph
    graph = gs.Graph(
        nodes=g1.nodes + g2.nodes + nodes,
        inputs=[img_input, demo_input],
        outputs=[probs],
        opset=17,
    )

    model = gs.export_onnx(graph.cleanup().toposort())
    model = shape_inference.infer_shapes(model)
    onnx.checker.check_model(model, full_check=True)  # now passes!
    onnx.save(model, output_path)
    print(f"Combined model saved: {output_path}")
    # return model


# === RUN ===
create_combined_onnx(
    model_path1="../../models/2025-11-27/speechmaster/62_model94.onnx",
    model_path2="../../models/2025-11-27/speechmaster/18_model118.onnx",
    whitelist1=[1, 2, 4, 8, 9],
    whitelist2=[0, 3, 10, 5, 6, 7],
    scale_factor=3.0,
    output_path="../../models/combine/2025-11-27/62vs18_exclusive.onnx",
)

Combined model saved: ../../models/combine/2025-11-27/62vs18_exclusive.onnx


### Combining 18 vs 122

#### Strategy 1

In [61]:
import onnx
import onnx_graphsurgeon as gs
import numpy as np
from onnx import shape_inference
from typing import List


def fix_reduction_nodes(graph: gs.Graph, graph_name: str = "unknown"):
    """
    Fixes ReduceL2 and ReduceMean nodes that incorrectly have axes as input (2 inputs) by moving axes to attribute.
    Searches for the Constant node producing the axes Variable and extracts its value.
    Removes the axes input and the unused Constant node after fix.
    Adds debug prints for all ReduceL2 and ReduceMean nodes.
    """
    fixed_count = 0
    removed_constants = 0
    debug_nodes = []
    for node in graph.nodes:
        if node.op in ["ReduceL2", "ReduceMean"]:
            debug_nodes.append(
                {
                    "name": node.name,
                    "op": node.op,
                    "inputs_count": len(node.inputs),
                    "inputs_types": [type(inp).__name__ for inp in node.inputs],
                    "second_input_name": (
                        node.inputs[1].name if len(node.inputs) > 1 else None
                    ),
                }
            )
            if len(node.inputs) == 2:
                data_input = node.inputs[0]
                axes_var = node.inputs[1]
                # Search for Constant node producing axes_var
                constant_node = None
                axes_values = None
                for c_node in graph.nodes:
                    if (
                        c_node.op == "Constant"
                        and c_node.outputs
                        and len(c_node.outputs) == 1
                        and c_node.outputs[0].name == axes_var.name
                    ):
                        constant_node = c_node
                        if "value" in c_node.attrs:
                            axes_values = c_node.attrs["value"].values
                            if isinstance(axes_values, np.ndarray):
                                axes_values = axes_values.tolist()
                        break
                if constant_node and axes_values is not None:
                    # Update node: remove second input, add axes attr
                    node.inputs = [data_input]
                    node.attrs["axes"] = axes_values
                    # Ensure keepdims is set (default 1 for most reductions)
                    if "keepdims" not in node.attrs:
                        node.attrs["keepdims"] = 1
                    fixed_count += 1
                    print(
                        f"[{graph_name}] Fixed {node.op} node '{node.name}': axes {axes_values} extracted from Constant '{constant_node.name}'"
                    )
                    # Mark for removal; cleanup will handle unused nodes
                    removed_constants += 1
                else:
                    print(
                        f"[{graph_name}] Warning: Could not find/extract axes for {node.op} '{node.name}'; second input '{axes_var.name}', Constant found: {constant_node is not None}"
                    )
    if debug_nodes:
        print(
            f"[{graph_name}] Total {', '.join(set(dn['op'] for dn in debug_nodes))} nodes: {len(debug_nodes)}, Fixed: {fixed_count}"
        )
        for dn in debug_nodes[:3]:  # Print first 3 for brevity
            print(
                f"  - {dn['name']}: {dn['op']}, {dn['inputs_count']} inputs, types: {dn['inputs_types']}, second_name: {dn['second_input_name']}"
            )
        if len(debug_nodes) > 3:
            print(f"  ... and {len(debug_nodes)-3} more")
    return fixed_count


def _rename_graph_tensors_and_nodes(
    graph: gs.Graph, prefix: str, skip_vars: List[gs.Variable] = None
):
    """Prefix all tensor and node names in `graph` with `prefix`, except variables in skip_vars.

    This avoids name collisions when combining multiple graphs. We compare skip_vars by object id to
    ensure we don't rename the shared input Variable object.
    """
    if skip_vars is None:
        skip_vars = []
    skip_ids = {id(v) for v in skip_vars}

    # Rename variables (tensors)
    tensors = list(graph.tensors().values())
    for var in tensors:
        if id(var) in skip_ids:
            continue
        if var.name:
            var.name = prefix + var.name

    # Rename nodes
    for node in graph.nodes:
        if node.name:
            node.name = prefix + node.name


def create_combined_onnx(
    model_path1,
    model_path2,
    whitelist1: List[int],
    whitelist2: List[int],
    scale_const: float = 3.0,
    output_path="combined.onnx",
):
    """
    Combines two ONNX models into one with conditional logic based on whitelists for skin cancer strategy:
    - Compute logits1 from model1.
    - If argmax(logits1) in whitelist1, use logits1.
    - Else compute logits2 from model2.
    - If argmax(logits2) in whitelist2, use logits2.
    - Else use average of logits1 and logits2.
    - Then scale the chosen logits by scale_const and apply softmax for output.

    Assumes:
    - Both models take 'image' and 'demographics' as inputs.
    - Both output a single tensor of shape [batch_size, num_classes] (logits).
    - Whitelists are lists of class indices (integers).
    - You need to pip install onnx onnx-graphsurgeon if not already installed.
    """
    # Load the models
    onnx_model1 = onnx.load(model_path1)
    onnx_model2 = onnx.load(model_path2)

    # Import into graph surgeon
    graph1 = gs.import_onnx(onnx_model1)
    graph2 = gs.import_onnx(onnx_model2)

    # Fix reduction nodes in BOTH graphs for thoroughness
    total_fixed = 0
    total_fixed += fix_reduction_nodes(graph1, "Model1")
    total_fixed += fix_reduction_nodes(graph2, "Model2")
    if total_fixed == 0:
        print("No reduction fixes applied - check debug output above")

    # Assume input names and order based on your code
    # Rename for clarity
    image_input = graph1.inputs[0]
    image_input.name = "image"

    demographics_input = graph1.inputs[1]
    demographics_input.name = "demographics"

    # Share the inputs with model2
    old_image_input = graph2.inputs[0]
    old_demo_input = graph2.inputs[1]

    _rename_graph_tensors_and_nodes(
        graph2, prefix="g2_", skip_vars=[old_image_input, old_demo_input]
    )

    # Replace all references in graph2 nodes to use shared inputs
    for node in graph2.nodes:
        for i in range(len(node.inputs)):
            if node.inputs[i] is old_image_input:
                node.inputs[i] = image_input
            if node.inputs[i] is old_demo_input:
                node.inputs[i] = demographics_input

    # Update graph2's inputs list to use the shared inputs
    graph2.inputs[0] = image_input
    graph2.inputs[1] = demographics_input

    # Get outputs (assume single output each)
    logits1 = graph1.outputs[0]
    logits1.name = "logits1"

    logits2 = graph2.outputs[0]
    logits2.name = "logits2"

    # Extract num_classes from logits1 shape (assume [batch, num_classes]; batch dynamic)
    orig_shape = logits1.shape
    if orig_shape and len(orig_shape) >= 2:
        num_classes = orig_shape[-1]
        if num_classes == 0 or num_classes is None:
            num_classes = 11  # Fallback assumption
        output_shape = [None, num_classes]  # Dynamic batch
    else:
        output_shape = [None, 11]  # Fallback
        num_classes = 11
        print(
            f"Warning: Could not infer num_classes from shape {orig_shape}; using fallback [None, 11]"
        )

    print(f"Inferred output shape: {output_shape}")

    # Define variables
    class1 = gs.Variable("class1", shape=[None], dtype=np.int64)
    class1_unsq = gs.Variable("class1_unsq", shape=[None, 1], dtype=np.int64)
    eq1 = gs.Variable("eq1", shape=[None, len(whitelist1)], dtype=np.bool)
    cast1 = gs.Variable("cast1", shape=[None, len(whitelist1)], dtype=np.float32)
    reduce1 = gs.Variable("reduce1", shape=[None, 1], dtype=np.float32)
    is_in1 = gs.Variable("is_in1", shape=[None, 1], dtype=np.bool)
    is_in1_exp = gs.Variable("is_in1_exp", shape=output_shape, dtype=np.bool)

    class2 = gs.Variable("class2", shape=[None], dtype=np.int64)
    class2_unsq = gs.Variable("class2_unsq", shape=[None, 1], dtype=np.int64)
    eq2 = gs.Variable("eq2", shape=[None, len(whitelist2)], dtype=np.bool)
    cast2 = gs.Variable("cast2", shape=[None, len(whitelist2)], dtype=np.float32)
    reduce2 = gs.Variable("reduce2", shape=[None, 1], dtype=np.float32)
    is_in2 = gs.Variable("is_in2", shape=[None, 1], dtype=np.bool)
    is_in2_exp = gs.Variable("is_in2_exp", shape=output_shape, dtype=np.bool)

    logits_shape_var = gs.Variable("logits_shape", dtype=np.int64, shape=[2])

    add_output = gs.Variable("add_output", shape=output_shape, dtype=np.float32)
    avg_logits = gs.Variable("avg_logits", shape=output_shape, dtype=np.float32)
    inner_selected = gs.Variable("inner_selected", shape=output_shape, dtype=np.float32)
    selected_logits = gs.Variable(
        "selected_logits", shape=output_shape, dtype=np.float32
    )
    scaled_logits = gs.Variable("scaled_logits", shape=output_shape, dtype=np.float32)
    final_output = gs.Variable("final_output", shape=output_shape, dtype=np.float32)

    # Constants
    whitelist1_const = gs.Constant(
        "whitelist1", values=np.array(whitelist1, dtype=np.int64)
    )
    whitelist2_const = gs.Constant(
        "whitelist2", values=np.array(whitelist2, dtype=np.int64)
    )
    zero_const = gs.Constant("zero", values=np.array(0.0, dtype=np.float32))
    two_const = gs.Constant("two", values=np.array(2.0, dtype=np.float32))
    scale_const_node = gs.Constant(
        "scale", values=np.array(scale_const, dtype=np.float32)
    )
    axes_unsq = gs.Constant("axes_unsq", values=np.array([1], dtype=np.int64))
    axes_reduce = gs.Constant("axes_reduce", values=np.array([1], dtype=np.int64))

    # Nodes for whitelist1 check
    argmax1 = gs.Node(
        op="ArgMax", inputs=[logits1], outputs=[class1], attrs={"axis": 1, "keepdims": 0}
    )
    unsqueeze1 = gs.Node(
        op="Unsqueeze", inputs=[class1, axes_unsq], outputs=[class1_unsq]
    )
    equal1 = gs.Node(op="Equal", inputs=[class1_unsq, whitelist1_const], outputs=[eq1])
    cast1_node = gs.Node(
        op="Cast", inputs=[eq1], outputs=[cast1], attrs={"to": onnx.TensorProto.FLOAT}
    )
    reducesum1 = gs.Node(
        op="ReduceSum",
        inputs=[cast1, axes_reduce],
        outputs=[reduce1],
        attrs={"keepdims": 1},
    )
    greater1 = gs.Node(op="Greater", inputs=[reduce1, zero_const], outputs=[is_in1])

    # Nodes for whitelist2 check
    argmax2 = gs.Node(
        op="ArgMax", inputs=[logits2], outputs=[class2], attrs={"axis": 1, "keepdims": 0}
    )
    unsqueeze2 = gs.Node(
        op="Unsqueeze", inputs=[class2, axes_unsq], outputs=[class2_unsq]
    )
    equal2 = gs.Node(op="Equal", inputs=[class2_unsq, whitelist2_const], outputs=[eq2])
    cast2_node = gs.Node(
        op="Cast", inputs=[eq2], outputs=[cast2], attrs={"to": onnx.TensorProto.FLOAT}
    )
    reducesum2 = gs.Node(
        op="ReduceSum",
        inputs=[cast2, axes_reduce],
        outputs=[reduce2],
        attrs={"keepdims": 1},
    )
    greater2 = gs.Node(op="Greater", inputs=[reduce2, zero_const], outputs=[is_in2])

    # Dynamic shape for expand
    shape_node = gs.Node(op="Shape", inputs=[logits1], outputs=[logits_shape_var])

    expand1 = gs.Node(
        op="Expand", inputs=[is_in1, logits_shape_var], outputs=[is_in1_exp]
    )
    expand2 = gs.Node(
        op="Expand", inputs=[is_in2, logits_shape_var], outputs=[is_in2_exp]
    )

    # Combine logic
    add_logits = gs.Node(op="Add", inputs=[logits1, logits2], outputs=[add_output])
    avg_node = gs.Node(op="Div", inputs=[add_output, two_const], outputs=[avg_logits])
    inner_where = gs.Node(
        op="Where",
        inputs=[is_in2_exp, logits2, avg_logits],
        outputs=[inner_selected],
    )
    outer_where = gs.Node(
        op="Where",
        inputs=[is_in1_exp, logits1, inner_selected],
        outputs=[selected_logits],
    )
    scale_mul = gs.Node(
        op="Mul", inputs=[selected_logits, scale_const_node], outputs=[scaled_logits]
    )
    softmax_final = gs.Node(
        op="Softmax", inputs=[scaled_logits], outputs=[final_output], attrs={"axis": 1}
    )

    # Combined graph: nodes from both + new nodes; inputs: image + demographics; output: final_output
    combined_graph = gs.Graph(
        nodes=graph1.nodes
        + graph2.nodes
        + [
            argmax1,
            unsqueeze1,
            equal1,
            cast1_node,
            reducesum1,
            greater1,
            argmax2,
            unsqueeze2,
            equal2,
            cast2_node,
            reducesum2,
            greater2,
            shape_node,
            expand1,
            expand2,
            add_logits,
            avg_node,
            inner_where,
            outer_where,
            scale_mul,
            softmax_final,
        ],
        inputs=[image_input, demographics_input],
        outputs=[final_output],
    )

    # Set opset on the graph for LayerNormalization support (opset 17+)
    combined_graph.opset = 17

    # Cleanup and export
    combined_model = gs.export_onnx(combined_graph.cleanup())

    # Infer shapes to fill in any missing (helps checker)
    combined_model = shape_inference.infer_shapes(combined_model)

    # Optional: Check model
    onnx.checker.check_model(combined_model)

    # Save
    onnx.save(combined_model, output_path)
    print(f"Combined ONNX model saved to {output_path}")
    print(f"Output shape: {output_shape}")

    return combined_model


# Usage example (replace with your actual whitelists)
whitelist18=[0, 3, 4, 10]
whitelist122=[1, 2, 5, 8, 9]  # Example: classes for model2
# whitelist1=[0, 3, 10, 9, 6, 7]
# whitelist2=[1, 2, 4, 5, 8]  # Example: classes for model2
combined = create_combined_onnx(
    "../../models/2025-11-27/speechmaster/18_model118.onnx",
    "../../models/2025-11-27/speechmaster/122_model123.onnx",
    whitelist1=whitelist18,
    whitelist2=whitelist122,
    scale_const=3.0,
    output_path="../../models/combine/2025-11-27/18vs122.onnx",
)

[Model1] Total ReduceMean nodes: 34, Fixed: 0
  - /model/base_model/blocks/blocks.0/blocks.0.0/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  - /model/base_model/blocks/blocks.0/blocks.0.1/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  - /model/base_model/blocks/blocks.1/blocks.1.0/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  ... and 31 more
[Model2] Total ReduceMean nodes: 34, Fixed: 0
  - /model/base_model/blocks/blocks.0/blocks.0.0/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  - /model/base_model/blocks/blocks.0/blocks.0.1/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  - /model/base_model/blocks/blocks.1/blocks.1.0/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  ... and 31 more
No reduction fixes applied - check debug output above
Inferred output shape: [None, 11]
Combined ONNX model saved to ../..

#### Strategy 2

In [62]:
import onnx
import onnx_graphsurgeon as gs
import numpy as np
from onnx import shape_inference
from typing import List, Tuple


def _fix_reduction_nodes(graph: gs.Graph) -> None:
    for node in graph.nodes:
        if node.op in {"ReduceMean", "ReduceL2"} and len(node.inputs) == 2:
            data, axes = node.inputs
            const = next((n for n in graph.nodes if n.op == "Constant" and n.outputs[0] is axes), None)
            if const and hasattr(const.attrs["value"], "values"):
                node.inputs = [data]
                node.attrs["axes"] = const.attrs["value"].values.tolist()
                node.attrs["keepdims"] = 1


def _share_inputs_and_rename(g1: gs.Graph, g2: gs.Graph, prefix: str = "m2_"):
    img = g1.inputs[0]
    demo = g1.inputs[1]
    img.name = "image"
    demo.name = "demographics"

    old_img, old_demo = g2.inputs[0], g2.inputs[1]
    skip = {id(old_img), id(old_demo)}
    for v in g2.tensors().values():
        if id(v) not in skip and v.name:
            v.name = prefix + v.name
    for n in g2.nodes:
        if n.name:
            n.name = prefix + n.name

    for n in g2.nodes:
        for i, inp in enumerate(n.inputs):
            if inp is old_img:
                n.inputs[i] = img
            if inp is old_demo:
                n.inputs[i] = demo
    g2.inputs = [img, demo]
    return img, demo


def create_combined_onnx(
    model_path1: str,
    model_path2: str,
    whitelist1: List[int],
    whitelist2: List[int],
    scale_factor: float = 3.0,
    output_path: str = "combined.onnx",
) -> onnx.ModelProto:
    # Load & prepare graphs
    g1 = gs.import_onnx(onnx.load(model_path1))
    g2 = gs.import_onnx(onnx.load(model_path2))
    _fix_reduction_nodes(g1)
    _fix_reduction_nodes(g2)
    img_input, demo_input = _share_inputs_and_rename(g1, g2, "m2_")

    logits1 = g1.outputs[0]
    logits2 = g2.outputs[0]
    logits1.name = "logits_m1"
    logits2.name = "logits_m2"

    num_classes = logits1.shape[1] if logits1.shape and len(logits1.shape) == 2 else 11
    batch_shape = (None, num_classes)

    # Constants
    w1 = gs.Constant("w1", np.array(whitelist1, np.int64))
    w2 = gs.Constant("w2", np.array(whitelist2, np.int64))
    zero = gs.Constant("zero", np.array(0.0, np.float32))
    two = gs.Constant("two", np.array(2.0, np.float32))
    scale = gs.Constant("scale", np.array(scale_factor, np.float32))
    axis1 = gs.Constant("axis1", np.array([1], np.int64))

    # Helper: membership test that works with ReduceSum
    def membership_nodes(class_unsq: gs.Variable, whitelist_const: gs.Constant, prefix: str):
        eq = gs.Variable(f"{prefix}_eq", dtype=np.bool)
        cast = gs.Variable(f"{prefix}_cast", dtype=np.float32)
        reduced = gs.Variable(f"{prefix}_reduced", dtype=np.float32, shape=["batch", 1])
        is_member = gs.Variable(f"{prefix}_member", dtype=np.bool, shape=["batch", 1])

        nodes = [
            gs.Node(op="Equal", name=f"eq_{prefix}", inputs=[class_unsq, whitelist_const], outputs=[eq]),
            gs.Node(op="Cast", name=f"cast_{prefix}", inputs=[eq], outputs=[cast],
                    attrs={"to": onnx.TensorProto.FLOAT}),
            gs.Node(op="ReduceSum", name=f"red_{prefix}", inputs=[cast, axis1],
                    outputs=[reduced], attrs={"keepdims": 1}),
            gs.Node(op="Greater", name=f"gt_{prefix}", inputs=[reduced, zero], outputs=[is_member]),
        ]
        return nodes, is_member

    # ArgMax + Unsqueeze
    c1 = gs.Variable("c1", np.int64, ["batch"])
    c2 = gs.Variable("c2", np.int64, ["batch"])
    c1_u = gs.Variable("c1_u", np.int64, ["batch", 1])
    c2_u = gs.Variable("c2_u", np.int64, ["batch", 1])

    nodes = [
        gs.Node(op="ArgMax", name="argmax1", inputs=[logits1], outputs=[c1], attrs={"axis": 1, "keepdims": 0}),
        gs.Node(op="ArgMax", name="argmax2", inputs=[logits2], outputs=[c2], attrs={"axis": 1, "keepdims": 0}),
        gs.Node(op="Unsqueeze", name="unsq1", inputs=[c1, axis1], outputs=[c1_u]),
        gs.Node(op="Unsqueeze", name="unsq2", inputs=[c2, axis1], outputs=[c2_u]),
    ]

    # Membership checks
    n1, m1_in_w1 = membership_nodes(c1_u, w1, "m1_w1")
    n2, m1_in_w2 = membership_nodes(c1_u, w2, "m1_w2")
    n3, m2_in_w1 = membership_nodes(c2_u, w1, "m2_w1")
    n4, m2_in_w2 = membership_nodes(c2_u, w2, "m2_w2")
    nodes += n1 + n2 + n3 + n4

    # Conditions
    m1_not_w2 = gs.Variable("m1_not_w2", np.bool, ["batch", 1])
    m2_not_w1 = gs.Variable("m2_not_w1", np.bool, ["batch", 1])
    use_m1 = gs.Variable("use_model1", np.bool, ["batch", 1])
    use_m2 = gs.Variable("use_model2", np.bool, ["batch", 1])

    nodes += [
        gs.Node(op="Not", name="not1", inputs=[m1_in_w2], outputs=[m1_not_w2]),
        gs.Node(op="Not", name="not2", inputs=[m2_in_w1], outputs=[m2_not_w1]),
        gs.Node(op="And", name="and_m1", inputs=[m1_in_w1, m1_not_w2], outputs=[use_m1]),
        gs.Node(op="And", name="and_m2", inputs=[m2_not_w1, m2_in_w2], outputs=[use_m2]),
    ]

    # Expand masks
    shape_var = gs.Variable("shape", np.int64, [2])
    use_m1_exp = gs.Variable("use_m1_exp", np.bool, batch_shape)
    use_m2_exp = gs.Variable("use_m2_exp", np.bool, batch_shape)

    nodes += [
        gs.Node(op="Shape", name="shape", inputs=[logits1], outputs=[shape_var]),
        gs.Node(op="Expand", name="exp_m1", inputs=[use_m1, shape_var], outputs=[use_m1_exp]),
        gs.Node(op="Expand", name="exp_m2", inputs=[use_m2, shape_var], outputs=[use_m2_exp]),
    ]

    # Average fallback
    sum_ab = gs.Variable("sum", np.float32, batch_shape)
    avg = gs.Variable("avg", np.float32, batch_shape)
    temp = gs.Variable("temp", np.float32, batch_shape)
    selected = gs.Variable("selected", np.float32, batch_shape)
    scaled = gs.Variable("scaled", np.float32, batch_shape)
    probs = gs.Variable("probabilities", np.float32, batch_shape)

    nodes += [
        gs.Node(op="Add", name="add", inputs=[logits1, logits2], outputs=[sum_ab]),
        gs.Node(op="Div", name="div", inputs=[sum_ab, two], outputs=[avg]),
        gs.Node(op="Where", name="where_m2", inputs=[use_m2_exp, logits2, avg], outputs=[temp]),
        gs.Node(op="Where", name="where_final", inputs=[use_m1_exp, logits1, temp], outputs=[selected]),
        gs.Node(op="Mul", name="mul", inputs=[selected, scale], outputs=[scaled]),
        gs.Node(op="Softmax", name="softmax", inputs=[scaled], outputs=[probs], attrs={"axis": 1}),
    ]

    # Final graph
    graph = gs.Graph(nodes=g1.nodes + g2.nodes + nodes,
                     inputs=[img_input, demo_input],
                     outputs=[probs],
                     opset=17)

    model = gs.export_onnx(graph.cleanup().toposort())
    model = shape_inference.infer_shapes(model)
    onnx.checker.check_model(model, full_check=True)   # now passes!
    onnx.save(model, output_path)
    print(f"Combined model saved: {output_path}")
    # return model


# Usage example (replace with your actual whitelists)
whitelist18=[0, 3, 4, 10]
whitelist122=[1, 2, 5, 8, 9]  
combined = create_combined_onnx(
    "../../models/2025-11-27/speechmaster/122_model123.onnx",
    "../../models/2025-11-27/speechmaster/18_model118.onnx",
    whitelist1=whitelist122,
    whitelist2=whitelist18,
    scale_factor=3.0,
    output_path="../../models/combine/2025-11-27/122vs18_exclusive.onnx",
)

Combined model saved: ../../models/combine/2025-11-27/122vs18_exclusive.onnx


### Combining 148 vs 196

In [None]:
import onnx
import onnx_graphsurgeon as gs
import numpy as np
from onnx import shape_inference
from typing import List


def fix_reduction_nodes(graph: gs.Graph, graph_name: str = "unknown"):
    """
    Fixes ReduceL2 and ReduceMean nodes that incorrectly have axes as input (2 inputs) by moving axes to attribute.
    Searches for the Constant node producing the axes Variable and extracts its value.
    Removes the axes input and the unused Constant node after fix.
    Adds debug prints for all ReduceL2 and ReduceMean nodes.
    """
    fixed_count = 0
    removed_constants = 0
    debug_nodes = []
    for node in graph.nodes:
        if node.op in ['ReduceL2', 'ReduceMean']:
            debug_nodes.append({
                'name': node.name,
                'op': node.op,
                'inputs_count': len(node.inputs),
                'inputs_types': [type(inp).__name__ for inp in node.inputs],
                'second_input_name': node.inputs[1].name if len(node.inputs) > 1 else None
            })
            if len(node.inputs) == 2:
                data_input = node.inputs[0]
                axes_var = node.inputs[1]
                # Search for Constant node producing axes_var
                constant_node = None
                axes_values = None
                for c_node in graph.nodes:
                    if (c_node.op == 'Constant' and 
                        c_node.outputs and len(c_node.outputs) == 1 and 
                        c_node.outputs[0].name == axes_var.name):
                        constant_node = c_node
                        if 'value' in c_node.attrs:
                            axes_values = c_node.attrs['value'].values
                            if isinstance(axes_values, np.ndarray):
                                axes_values = axes_values.tolist()
                        break
                if constant_node and axes_values is not None:
                    # Update node: remove second input, add axes attr
                    node.inputs = [data_input]
                    node.attrs['axes'] = axes_values
                    # Ensure keepdims is set (default 1 for most reductions)
                    if 'keepdims' not in node.attrs:
                        node.attrs['keepdims'] = 1
                    fixed_count += 1
                    print(f"[{graph_name}] Fixed {node.op} node '{node.name}': axes {axes_values} extracted from Constant '{constant_node.name}'")
                    # Mark for removal; cleanup will handle unused nodes
                    removed_constants += 1
                else:
                    print(f"[{graph_name}] Warning: Could not find/extract axes for {node.op} '{node.name}'; second input '{axes_var.name}', Constant found: {constant_node is not None}")
    if debug_nodes:
        print(f"[{graph_name}] Total {', '.join(set(dn['op'] for dn in debug_nodes))} nodes: {len(debug_nodes)}, Fixed: {fixed_count}")
        for dn in debug_nodes[:3]:  # Print first 3 for brevity
            print(f"  - {dn['name']}: {dn['op']}, {dn['inputs_count']} inputs, types: {dn['inputs_types']}, second_name: {dn['second_input_name']}")
        if len(debug_nodes) > 3:
            print(f"  ... and {len(debug_nodes)-3} more")
    return fixed_count


def _rename_graph_tensors_and_nodes(graph: gs.Graph, prefix: str, skip_vars: List[gs.Variable] = None):
    """Prefix all tensor and node names in `graph` with `prefix`, except variables in skip_vars.

    This avoids name collisions when combining multiple graphs. We compare skip_vars by object id to
    ensure we don't rename the shared input Variable object.
    """
    if skip_vars is None:
        skip_vars = []
    skip_ids = {id(v) for v in skip_vars}

    # Rename variables (tensors)
    tensors = list(graph.tensors().values())
    for var in tensors:
        if id(var) in skip_ids:
            continue
        if var.name:
            var.name = prefix + var.name

    # Rename nodes
    for node in graph.nodes:
        if node.name:
            node.name = prefix + node.name


def create_combined_onnx(model_path1, model_path2, output_path='combined.onnx'):
    """
    Combines two ONNX models into one:
    - Model1: takes 'image' and 'demographics' -> logits1
    - Model2: takes 'image' -> logits2
    - Combined: takes 'image' and 'demographics' -> (softmax(logits1) + logits2) / 2

    Key changes vs. earlier: we rename the second graph's tensors/nodes with a prefix to avoid name collisions
    and ensure the shared `image` input variable object is used by both graphs. This prevents duplicate tensor
    names and topological ordering issues during checker validation.
        """
    # Load the models
    onnx_model1 = onnx.load(model_path1)
    onnx_model2 = onnx.load(model_path2)

    # Import into graph surgeon
    graph1 = gs.import_onnx(onnx_model1)
    graph2 = gs.import_onnx(onnx_model2)

    # Fix reduction nodes in BOTH graphs for thoroughness
    total_fixed = 0
    total_fixed += fix_reduction_nodes(graph1, "Model1")
    total_fixed += fix_reduction_nodes(graph2, "Model2")
    if total_fixed == 0:
        print("No reduction fixes applied - check debug output above")

    # Rename for clarity and sharing
    image_input = graph1.inputs[0]
    image_input.name = 'image'

    demographics_input = graph1.inputs[1]
    demographics_input.name = 'demographics'

    # Grab model2's image input object BEFORE renaming so we can skip renaming that specific Variable
    old_image_input = graph2.inputs[0]
    old_demo_input = graph2.inputs[1]

    # Rename graph2 tensors/nodes to avoid clashes (but don't rename the image Variable object)
    _rename_graph_tensors_and_nodes(graph2, prefix='g2_', skip_vars=[old_image_input, old_demo_input])

    # Replace all references in graph2 nodes from old_image_input to the shared image_input object
    for node in graph2.nodes:
        for i in range(len(node.inputs)):
            if node.inputs[i] is old_image_input:
                node.inputs[i] = image_input
            if node.inputs[i] is old_demo_input:
                node.inputs[i] = demographics_input

    # Update graph2's inputs list to use the shared input object (this removes a duplicate input with same name)
    graph2.inputs[0] = image_input
    graph2.inputs[1] = demographics_input

    # Get outputs (assume single output each)
    logits1 = graph1.outputs[0]
    logits1.name = 'logits1'

    logits2 = graph2.outputs[0]
    logits2.name = 'logits2'

    # Extract num_classes from logits1 shape (assume [batch, num_classes]; batch dynamic)
    orig_shape = logits1.shape
    if orig_shape and len(orig_shape) >= 2:
        num_classes = orig_shape[-1]
        if num_classes == 0 or num_classes is None:
            num_classes = 11  # Fallback assumption based on reported output size
        output_shape = [None, num_classes]  # Dynamic batch
    else:
        output_shape = [None, 11]  # Fallback
        num_classes = 11
        print(f"Warning: Could not infer num_classes from shape {orig_shape}; using fallback [None, 10]")

    print(f"Inferred output shape: {output_shape}")

    # Define output variables WITH dtype and shape (no flattening)
    probs1 = gs.Variable('probs1', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    probs2 = gs.Variable('probs2', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    sum_avg = gs.Variable('sum_avg', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    avg_output1 = gs.Variable('avg_output1', shape=output_shape, dtype=onnx.TensorProto.FLOAT)
    avg_output2 = gs.Variable('avg_output2', shape=output_shape, dtype=onnx.TensorProto.FLOAT)

    # Softmax on first model (axis=1 for [batch, classes])
    softmax1 = gs.Node(
        op='Softmax',
        inputs=[logits1],
        outputs=[probs1],
        attrs={'axis': 1}
    )    
    softmax2 = gs.Node(
        op='Softmax',
        inputs=[logits2],
        outputs=[probs2],
        attrs={'axis': 1}
    )

    # Average: (probs1 + logits2) / 2

    constant_07 = gs.Constant(name='constant_07', values=np.array(0.3, dtype=np.float32))  # Scalar for broadcast
    constant_03 = gs.Constant(name='constant_03', values=np.array(0.7, dtype=np.float32))  # Scalar for broadcast
    mul1 = gs.Node(
        op='Mul',  # Equivalent to / 2
        inputs=[probs1, constant_03],
        outputs=[avg_output1]
    )
    mul2 = gs.Node(
        op='Mul',  # Equivalent to / 2
        inputs=[probs2, constant_07],
        outputs=[avg_output2]
    )
    add = gs.Node(
        op='Add',
        inputs=[avg_output1, avg_output2],
        outputs=[sum_avg]
    )

    # Combined graph: nodes from both + new nodes; inputs: image + demographics; output: avg_output
    # We put graph1 nodes first, then graph2 nodes (which we've namespaced) so producers appear before consumers.
    combined_graph = gs.Graph(
        nodes=graph1.nodes + graph2.nodes + [softmax1, softmax2, mul1, mul2, add],
        inputs=[image_input, demographics_input],
        outputs=[sum_avg]
    )

    # Set opset on the graph for LayerNormalization support (opset 17+)
    combined_graph.opset = 17

    # Cleanup and export - cleanup will remove unused nodes and should also fix ordering where possible
    combined_model = gs.export_onnx(combined_graph.cleanup())

    # Infer shapes to fill in any missing (helps checker)
    combined_model = shape_inference.infer_shapes(combined_model)

    # Optional: Check model
    onnx.checker.check_model(combined_model)

    # Save
    onnx.save(combined_model, output_path)
    print(f"Combined ONNX model saved to {output_path}")
    print(f"Output shape: {output_shape}")

    return combined_model

# Usage
# Note: adjust paths as needed
combined = create_combined_onnx('model/medicaldev_148.onnx', 'model/medicaldev_196.onnx', "model/medicaldev_148_196.onnx")


### Combining 61 vs 62

#### Strategy 1

In [54]:
import onnx
import onnx_graphsurgeon as gs
import numpy as np
from onnx import shape_inference
from typing import List


def fix_reduction_nodes(graph: gs.Graph, graph_name: str = "unknown"):
    """
    Fixes ReduceL2 and ReduceMean nodes that incorrectly have axes as input (2 inputs) by moving axes to attribute.
    Searches for the Constant node producing the axes Variable and extracts its value.
    Removes the axes input and the unused Constant node after fix.
    Adds debug prints for all ReduceL2 and ReduceMean nodes.
    """
    fixed_count = 0
    removed_constants = 0
    debug_nodes = []
    for node in graph.nodes:
        if node.op in ["ReduceL2", "ReduceMean"]:
            debug_nodes.append(
                {
                    "name": node.name,
                    "op": node.op,
                    "inputs_count": len(node.inputs),
                    "inputs_types": [type(inp).__name__ for inp in node.inputs],
                    "second_input_name": (
                        node.inputs[1].name if len(node.inputs) > 1 else None
                    ),
                }
            )
            if len(node.inputs) == 2:
                data_input = node.inputs[0]
                axes_var = node.inputs[1]
                # Search for Constant node producing axes_var
                constant_node = None
                axes_values = None
                for c_node in graph.nodes:
                    if (
                        c_node.op == "Constant"
                        and c_node.outputs
                        and len(c_node.outputs) == 1
                        and c_node.outputs[0].name == axes_var.name
                    ):
                        constant_node = c_node
                        if "value" in c_node.attrs:
                            axes_values = c_node.attrs["value"].values
                            if isinstance(axes_values, np.ndarray):
                                axes_values = axes_values.tolist()
                        break
                if constant_node and axes_values is not None:
                    # Update node: remove second input, add axes attr
                    node.inputs = [data_input]
                    node.attrs["axes"] = axes_values
                    # Ensure keepdims is set (default 1 for most reductions)
                    if "keepdims" not in node.attrs:
                        node.attrs["keepdims"] = 1
                    fixed_count += 1
                    print(
                        f"[{graph_name}] Fixed {node.op} node '{node.name}': axes {axes_values} extracted from Constant '{constant_node.name}'"
                    )
                    # Mark for removal; cleanup will handle unused nodes
                    removed_constants += 1
                else:
                    print(
                        f"[{graph_name}] Warning: Could not find/extract axes for {node.op} '{node.name}'; second input '{axes_var.name}', Constant found: {constant_node is not None}"
                    )
    if debug_nodes:
        print(
            f"[{graph_name}] Total {', '.join(set(dn['op'] for dn in debug_nodes))} nodes: {len(debug_nodes)}, Fixed: {fixed_count}"
        )
        for dn in debug_nodes[:3]:  # Print first 3 for brevity
            print(
                f"  - {dn['name']}: {dn['op']}, {dn['inputs_count']} inputs, types: {dn['inputs_types']}, second_name: {dn['second_input_name']}"
            )
        if len(debug_nodes) > 3:
            print(f"  ... and {len(debug_nodes)-3} more")
    return fixed_count


def _rename_graph_tensors_and_nodes(
    graph: gs.Graph, prefix: str, skip_vars: List[gs.Variable] = None
):
    """Prefix all tensor and node names in `graph` with `prefix`, except variables in skip_vars.

    This avoids name collisions when combining multiple graphs. We compare skip_vars by object id to
    ensure we don't rename the shared input Variable object.
    """
    if skip_vars is None:
        skip_vars = []
    skip_ids = {id(v) for v in skip_vars}

    # Rename variables (tensors)
    tensors = list(graph.tensors().values())
    for var in tensors:
        if id(var) in skip_ids:
            continue
        if var.name:
            var.name = prefix + var.name

    # Rename nodes
    for node in graph.nodes:
        if node.name:
            node.name = prefix + node.name


def create_combined_onnx(
    model_path1,
    model_path2,
    whitelist1: List[int],
    whitelist2: List[int],
    scale_const: float = 3.0,
    output_path="combined.onnx",
):
    """
    Combines two ONNX models into one with conditional logic based on whitelists for skin cancer strategy:
    - Compute logits1 from model1.
    - If argmax(logits1) in whitelist1, use logits1.
    - Else compute logits2 from model2.
    - If argmax(logits2) in whitelist2, use logits2.
    - Else use average of logits1 and logits2.
    - Then scale the chosen logits by scale_const and apply softmax for output.

    Assumes:
    - Both models take 'image' and 'demographics' as inputs.
    - Both output a single tensor of shape [batch_size, num_classes] (logits).
    - Whitelists are lists of class indices (integers).
    - You need to pip install onnx onnx-graphsurgeon if not already installed.
    """
    # Load the models
    onnx_model1 = onnx.load(model_path1)
    onnx_model2 = onnx.load(model_path2)

    # Import into graph surgeon
    graph1 = gs.import_onnx(onnx_model1)
    graph2 = gs.import_onnx(onnx_model2)

    # Fix reduction nodes in BOTH graphs for thoroughness
    total_fixed = 0
    total_fixed += fix_reduction_nodes(graph1, "Model1")
    total_fixed += fix_reduction_nodes(graph2, "Model2")
    if total_fixed == 0:
        print("No reduction fixes applied - check debug output above")

    # Assume input names and order based on your code
    # Rename for clarity
    image_input = graph1.inputs[0]
    image_input.name = "image"

    demographics_input = graph1.inputs[1]
    demographics_input.name = "demographics"

    # Share the inputs with model2
    old_image_input = graph2.inputs[0]
    old_demo_input = graph2.inputs[1]

    _rename_graph_tensors_and_nodes(
        graph2, prefix="g2_", skip_vars=[old_image_input, old_demo_input]
    )

    # Replace all references in graph2 nodes to use shared inputs
    for node in graph2.nodes:
        for i in range(len(node.inputs)):
            if node.inputs[i] is old_image_input:
                node.inputs[i] = image_input
            if node.inputs[i] is old_demo_input:
                node.inputs[i] = demographics_input

    # Update graph2's inputs list to use the shared inputs
    graph2.inputs[0] = image_input
    graph2.inputs[1] = demographics_input

    # Get outputs (assume single output each)
    logits1 = graph1.outputs[0]
    logits1.name = "logits1"

    logits2 = graph2.outputs[0]
    logits2.name = "logits2"

    # Extract num_classes from logits1 shape (assume [batch, num_classes]; batch dynamic)
    orig_shape = logits1.shape
    if orig_shape and len(orig_shape) >= 2:
        num_classes = orig_shape[-1]
        if num_classes == 0 or num_classes is None:
            num_classes = 11  # Fallback assumption
        output_shape = [None, num_classes]  # Dynamic batch
    else:
        output_shape = [None, 11]  # Fallback
        num_classes = 11
        print(
            f"Warning: Could not infer num_classes from shape {orig_shape}; using fallback [None, 11]"
        )

    print(f"Inferred output shape: {output_shape}")

    # Define variables
    class1 = gs.Variable("class1", shape=[None], dtype=np.int64)
    class1_unsq = gs.Variable("class1_unsq", shape=[None, 1], dtype=np.int64)
    eq1 = gs.Variable("eq1", shape=[None, len(whitelist1)], dtype=np.bool)
    cast1 = gs.Variable("cast1", shape=[None, len(whitelist1)], dtype=np.float32)
    reduce1 = gs.Variable("reduce1", shape=[None, 1], dtype=np.float32)
    is_in1 = gs.Variable("is_in1", shape=[None, 1], dtype=np.bool)
    is_in1_exp = gs.Variable("is_in1_exp", shape=output_shape, dtype=np.bool)

    class2 = gs.Variable("class2", shape=[None], dtype=np.int64)
    class2_unsq = gs.Variable("class2_unsq", shape=[None, 1], dtype=np.int64)
    eq2 = gs.Variable("eq2", shape=[None, len(whitelist2)], dtype=np.bool)
    cast2 = gs.Variable("cast2", shape=[None, len(whitelist2)], dtype=np.float32)
    reduce2 = gs.Variable("reduce2", shape=[None, 1], dtype=np.float32)
    is_in2 = gs.Variable("is_in2", shape=[None, 1], dtype=np.bool)
    is_in2_exp = gs.Variable("is_in2_exp", shape=output_shape, dtype=np.bool)

    logits_shape_var = gs.Variable("logits_shape", dtype=np.int64, shape=[2])

    add_output = gs.Variable("add_output", shape=output_shape, dtype=np.float32)
    avg_logits = gs.Variable("avg_logits", shape=output_shape, dtype=np.float32)
    inner_selected = gs.Variable("inner_selected", shape=output_shape, dtype=np.float32)
    selected_logits = gs.Variable(
        "selected_logits", shape=output_shape, dtype=np.float32
    )
    scaled_logits = gs.Variable("scaled_logits", shape=output_shape, dtype=np.float32)
    final_output = gs.Variable("final_output", shape=output_shape, dtype=np.float32)

    # Constants
    whitelist1_const = gs.Constant(
        "whitelist1", values=np.array(whitelist1, dtype=np.int64)
    )
    whitelist2_const = gs.Constant(
        "whitelist2", values=np.array(whitelist2, dtype=np.int64)
    )
    zero_const = gs.Constant("zero", values=np.array(0.0, dtype=np.float32))
    two_const = gs.Constant("two", values=np.array(2.0, dtype=np.float32))
    scale_const_node = gs.Constant(
        "scale", values=np.array(scale_const, dtype=np.float32)
    )
    axes_unsq = gs.Constant("axes_unsq", values=np.array([1], dtype=np.int64))
    axes_reduce = gs.Constant("axes_reduce", values=np.array([1], dtype=np.int64))

    # Nodes for whitelist1 check
    argmax1 = gs.Node(
        op="ArgMax", inputs=[logits1], outputs=[class1], attrs={"axis": 1, "keepdims": 0}
    )
    unsqueeze1 = gs.Node(
        op="Unsqueeze", inputs=[class1, axes_unsq], outputs=[class1_unsq]
    )
    equal1 = gs.Node(op="Equal", inputs=[class1_unsq, whitelist1_const], outputs=[eq1])
    cast1_node = gs.Node(
        op="Cast", inputs=[eq1], outputs=[cast1], attrs={"to": onnx.TensorProto.FLOAT}
    )
    reducesum1 = gs.Node(
        op="ReduceSum",
        inputs=[cast1, axes_reduce],
        outputs=[reduce1],
        attrs={"keepdims": 1},
    )
    greater1 = gs.Node(op="Greater", inputs=[reduce1, zero_const], outputs=[is_in1])

    # Nodes for whitelist2 check
    argmax2 = gs.Node(
        op="ArgMax", inputs=[logits2], outputs=[class2], attrs={"axis": 1, "keepdims": 0}
    )
    unsqueeze2 = gs.Node(
        op="Unsqueeze", inputs=[class2, axes_unsq], outputs=[class2_unsq]
    )
    equal2 = gs.Node(op="Equal", inputs=[class2_unsq, whitelist2_const], outputs=[eq2])
    cast2_node = gs.Node(
        op="Cast", inputs=[eq2], outputs=[cast2], attrs={"to": onnx.TensorProto.FLOAT}
    )
    reducesum2 = gs.Node(
        op="ReduceSum",
        inputs=[cast2, axes_reduce],
        outputs=[reduce2],
        attrs={"keepdims": 1},
    )
    greater2 = gs.Node(op="Greater", inputs=[reduce2, zero_const], outputs=[is_in2])

    # Dynamic shape for expand
    shape_node = gs.Node(op="Shape", inputs=[logits1], outputs=[logits_shape_var])

    expand1 = gs.Node(
        op="Expand", inputs=[is_in1, logits_shape_var], outputs=[is_in1_exp]
    )
    expand2 = gs.Node(
        op="Expand", inputs=[is_in2, logits_shape_var], outputs=[is_in2_exp]
    )

    # Combine logic
    add_logits = gs.Node(op="Add", inputs=[logits1, logits2], outputs=[add_output])
    avg_node = gs.Node(op="Div", inputs=[add_output, two_const], outputs=[avg_logits])
    inner_where = gs.Node(
        op="Where",
        inputs=[is_in2_exp, logits2, avg_logits],
        outputs=[inner_selected],
    )
    outer_where = gs.Node(
        op="Where",
        inputs=[is_in1_exp, logits1, inner_selected],
        outputs=[selected_logits],
    )
    scale_mul = gs.Node(
        op="Mul", inputs=[selected_logits, scale_const_node], outputs=[scaled_logits]
    )
    softmax_final = gs.Node(
        op="Softmax", inputs=[scaled_logits], outputs=[final_output], attrs={"axis": 1}
    )

    # Combined graph: nodes from both + new nodes; inputs: image + demographics; output: final_output
    combined_graph = gs.Graph(
        nodes=graph1.nodes
        + graph2.nodes
        + [
            argmax1,
            unsqueeze1,
            equal1,
            cast1_node,
            reducesum1,
            greater1,
            argmax2,
            unsqueeze2,
            equal2,
            cast2_node,
            reducesum2,
            greater2,
            shape_node,
            expand1,
            expand2,
            add_logits,
            avg_node,
            inner_where,
            outer_where,
            scale_mul,
            softmax_final,
        ],
        inputs=[image_input, demographics_input],
        outputs=[final_output],
    )

    # Set opset on the graph for LayerNormalization support (opset 17+)
    combined_graph.opset = 17

    # Cleanup and export
    combined_model = gs.export_onnx(combined_graph.cleanup())

    # Infer shapes to fill in any missing (helps checker)
    combined_model = shape_inference.infer_shapes(combined_model)

    # Optional: Check model
    onnx.checker.check_model(combined_model)

    # Save
    onnx.save(combined_model, output_path)
    print(f"Combined ONNX model saved to {output_path}")
    print(f"Output shape: {output_shape}")

    return combined_model


# # Usage example (replace with your actual whitelists)
# whitelist61=[0, 3, 4, 10]
# whitelist62=[1, 2, 5, 8, 9]  
# combined = create_combined_onnx(
#     "../../models/2025-11-27/speechmaster/122_model123.onnx",
#     "../../models/2025-11-27/speechmaster/18_model118.onnx",
#     whitelist1=whitelist122,
#     whitelist2=whitelist18,
#     scale_count=3.0,
#     output_path="../../models/combine/2025-11-27/122vs18_exclusive.onnx",
# )

# Usage example (replace with your actual whitelists)
whitelist1=[0, 3, 10, 9, 6, 7]
whitelist2=[1, 2, 4, 5, 8]  # Example: classes for model2
# whitelist1=[0, 3, 10, 9, 6, 7]
# whitelist2=[1, 2, 4, 5, 8]  # Example: classes for model2
combined = create_combined_onnx(
    "../../models/2025-11-27/speechmaster/62_model94.onnx",
    "../../models/2025-11-27/grose/61_model08.onnx",
    whitelist1,
    whitelist2,
    3.0,
    "../../models/combine/2025-11-27/62vs61.onnx",
)

[Model1] Total ReduceMean nodes: 32, Fixed: 0
  - /model/base_model/blocks/blocks.0/blocks.0.0/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  - /model/base_model/blocks/blocks.0/blocks.0.1/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  - /model/base_model/blocks/blocks.1/blocks.1.0/se/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  ... and 29 more
[Model2] Total ReduceMean nodes: 34, Fixed: 0
  - /cnn/features/features.1/features.1.0/block/block.1/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  - /cnn/features/features.1/features.1.1/block/block.1/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  - /cnn/features/features.2/features.2.0/block/block.2/ReduceMean: ReduceMean, 1 inputs, types: ['Variable'], second_name: None
  ... and 31 more
No reduction fixes applied - check debug output above
Inferred output shape: [None, 11]
Combined ONNX model sa

#### Strategy 2

In [55]:
import onnx
import onnx_graphsurgeon as gs
import numpy as np
from onnx import shape_inference
from typing import List, Tuple


def _fix_reduction_nodes(graph: gs.Graph) -> None:
    for node in graph.nodes:
        if node.op in {"ReduceMean", "ReduceL2"} and len(node.inputs) == 2:
            data, axes = node.inputs
            const = next((n for n in graph.nodes if n.op == "Constant" and n.outputs[0] is axes), None)
            if const and hasattr(const.attrs["value"], "values"):
                node.inputs = [data]
                node.attrs["axes"] = const.attrs["value"].values.tolist()
                node.attrs["keepdims"] = 1


def _share_inputs_and_rename(g1: gs.Graph, g2: gs.Graph, prefix: str = "m2_"):
    img = g1.inputs[0]
    demo = g1.inputs[1]
    img.name = "image"
    demo.name = "demographics"

    old_img, old_demo = g2.inputs[0], g2.inputs[1]
    skip = {id(old_img), id(old_demo)}
    for v in g2.tensors().values():
        if id(v) not in skip and v.name:
            v.name = prefix + v.name
    for n in g2.nodes:
        if n.name:
            n.name = prefix + n.name

    for n in g2.nodes:
        for i, inp in enumerate(n.inputs):
            if inp is old_img:
                n.inputs[i] = img
            if inp is old_demo:
                n.inputs[i] = demo
    g2.inputs = [img, demo]
    return img, demo


def create_combined_onnx(
    model_path1: str,
    model_path2: str,
    whitelist1: List[int],
    whitelist2: List[int],
    scale_factor: float = 3.0,
    output_path: str = "combined.onnx",
) -> onnx.ModelProto:
    # Load & prepare graphs
    g1 = gs.import_onnx(onnx.load(model_path1))
    g2 = gs.import_onnx(onnx.load(model_path2))
    _fix_reduction_nodes(g1)
    _fix_reduction_nodes(g2)
    img_input, demo_input = _share_inputs_and_rename(g1, g2, "m2_")

    logits1 = g1.outputs[0]
    logits2 = g2.outputs[0]
    logits1.name = "logits_m1"
    logits2.name = "logits_m2"

    num_classes = logits1.shape[1] if logits1.shape and len(logits1.shape) == 2 else 11
    batch_shape = (None, num_classes)

    # Constants
    w1 = gs.Constant("w1", np.array(whitelist1, np.int64))
    w2 = gs.Constant("w2", np.array(whitelist2, np.int64))
    zero = gs.Constant("zero", np.array(0.0, np.float32))
    two = gs.Constant("two", np.array(2.0, np.float32))
    scale = gs.Constant("scale", np.array(scale_factor, np.float32))
    axis1 = gs.Constant("axis1", np.array([1], np.int64))

    # Helper: membership test that works with ReduceSum
    def membership_nodes(class_unsq: gs.Variable, whitelist_const: gs.Constant, prefix: str):
        eq = gs.Variable(f"{prefix}_eq", dtype=np.bool)
        cast = gs.Variable(f"{prefix}_cast", dtype=np.float32)
        reduced = gs.Variable(f"{prefix}_reduced", dtype=np.float32, shape=["batch", 1])
        is_member = gs.Variable(f"{prefix}_member", dtype=np.bool, shape=["batch", 1])

        nodes = [
            gs.Node(op="Equal", name=f"eq_{prefix}", inputs=[class_unsq, whitelist_const], outputs=[eq]),
            gs.Node(op="Cast", name=f"cast_{prefix}", inputs=[eq], outputs=[cast],
                    attrs={"to": onnx.TensorProto.FLOAT}),
            gs.Node(op="ReduceSum", name=f"red_{prefix}", inputs=[cast, axis1],
                    outputs=[reduced], attrs={"keepdims": 1}),
            gs.Node(op="Greater", name=f"gt_{prefix}", inputs=[reduced, zero], outputs=[is_member]),
        ]
        return nodes, is_member

    # ArgMax + Unsqueeze
    c1 = gs.Variable("c1", np.int64, ["batch"])
    c2 = gs.Variable("c2", np.int64, ["batch"])
    c1_u = gs.Variable("c1_u", np.int64, ["batch", 1])
    c2_u = gs.Variable("c2_u", np.int64, ["batch", 1])

    nodes = [
        gs.Node(op="ArgMax", name="argmax1", inputs=[logits1], outputs=[c1], attrs={"axis": 1, "keepdims": 0}),
        gs.Node(op="ArgMax", name="argmax2", inputs=[logits2], outputs=[c2], attrs={"axis": 1, "keepdims": 0}),
        gs.Node(op="Unsqueeze", name="unsq1", inputs=[c1, axis1], outputs=[c1_u]),
        gs.Node(op="Unsqueeze", name="unsq2", inputs=[c2, axis1], outputs=[c2_u]),
    ]

    # Membership checks
    n1, m1_in_w1 = membership_nodes(c1_u, w1, "m1_w1")
    n2, m1_in_w2 = membership_nodes(c1_u, w2, "m1_w2")
    n3, m2_in_w1 = membership_nodes(c2_u, w1, "m2_w1")
    n4, m2_in_w2 = membership_nodes(c2_u, w2, "m2_w2")
    nodes += n1 + n2 + n3 + n4

    # Conditions
    m1_not_w2 = gs.Variable("m1_not_w2", np.bool, ["batch", 1])
    m2_not_w1 = gs.Variable("m2_not_w1", np.bool, ["batch", 1])
    use_m1 = gs.Variable("use_model1", np.bool, ["batch", 1])
    use_m2 = gs.Variable("use_model2", np.bool, ["batch", 1])

    nodes += [
        gs.Node(op="Not", name="not1", inputs=[m1_in_w2], outputs=[m1_not_w2]),
        gs.Node(op="Not", name="not2", inputs=[m2_in_w1], outputs=[m2_not_w1]),
        gs.Node(op="And", name="and_m1", inputs=[m1_in_w1, m1_not_w2], outputs=[use_m1]),
        gs.Node(op="And", name="and_m2", inputs=[m2_not_w1, m2_in_w2], outputs=[use_m2]),
    ]

    # Expand masks
    shape_var = gs.Variable("shape", np.int64, [2])
    use_m1_exp = gs.Variable("use_m1_exp", np.bool, batch_shape)
    use_m2_exp = gs.Variable("use_m2_exp", np.bool, batch_shape)

    nodes += [
        gs.Node(op="Shape", name="shape", inputs=[logits1], outputs=[shape_var]),
        gs.Node(op="Expand", name="exp_m1", inputs=[use_m1, shape_var], outputs=[use_m1_exp]),
        gs.Node(op="Expand", name="exp_m2", inputs=[use_m2, shape_var], outputs=[use_m2_exp]),
    ]

    # Average fallback
    sum_ab = gs.Variable("sum", np.float32, batch_shape)
    avg = gs.Variable("avg", np.float32, batch_shape)
    temp = gs.Variable("temp", np.float32, batch_shape)
    selected = gs.Variable("selected", np.float32, batch_shape)
    scaled = gs.Variable("scaled", np.float32, batch_shape)
    probs = gs.Variable("probabilities", np.float32, batch_shape)

    nodes += [
        gs.Node(op="Add", name="add", inputs=[logits1, logits2], outputs=[sum_ab]),
        gs.Node(op="Div", name="div", inputs=[sum_ab, two], outputs=[avg]),
        gs.Node(op="Where", name="where_m2", inputs=[use_m2_exp, logits2, avg], outputs=[temp]),
        gs.Node(op="Where", name="where_final", inputs=[use_m1_exp, logits1, temp], outputs=[selected]),
        gs.Node(op="Mul", name="mul", inputs=[selected, scale], outputs=[scaled]),
        gs.Node(op="Softmax", name="softmax", inputs=[scaled], outputs=[probs], attrs={"axis": 1}),
    ]

    # Final graph
    graph = gs.Graph(nodes=g1.nodes + g2.nodes + nodes,
                     inputs=[img_input, demo_input],
                     outputs=[probs],
                     opset=17)

    model = gs.export_onnx(graph.cleanup().toposort())
    model = shape_inference.infer_shapes(model)
    onnx.checker.check_model(model, full_check=True)   # now passes!
    onnx.save(model, output_path)
    print(f"Combined model saved: {output_path}")
    # return model


# === RUN ===
create_combined_onnx(
    model_path1="../../models/2025-11-27/speechmaster/62_model94.onnx",
    model_path2="../../models/2025-11-27/grose/61_model08.onnx",
    whitelist1=[0, 3, 10, 9, 6, 7],
    whitelist2=[1, 2, 4, 5, 8],
    scale_factor=3.0,
    output_path="../../models/combine/2025-11-27/62_61_exclusive.onnx"
)

Combined model saved: ../../models/combine/2025-11-27/62_61_exclusive.onnx


### Down version

In [None]:
import onnx

# Load the original model
model = onnx.load(f"../../models/combine/2025-11-27/18vs122_1.onnx")

# Check original details (optional: for debugging)
print("Original IR version:", model.ir_version)
print("Original opset versions:", [(imp.domain, imp.version) for imp in model.opset_import])

# Downgrade IR version to 11 (your runtime's max)
model.ir_version = 10

# Save the downgraded model
downgraded_path = f"../../models/combine/2025-11-27/18vs122_1_down.onnx"
onnx.save(model, downgraded_path)
print(f"Downgraded model saved to: {downgraded_path}")

In [None]:
import numpy
import onnx
from onnx.helper import (
    make_node, make_graph, make_model, make_tensor_value_info)
from onnx.numpy_helper import from_array
from onnx.checker import check_model
from onnxruntime import InferenceSession

# initializers
value = numpy.array([0], dtype=numpy.float32)
zero = from_array(value, name='zero')

# Same as before, X is the input, Y is the output.
X = make_tensor_value_info('X', onnx.TensorProto.FLOAT, [None, None])
Y = make_tensor_value_info('Y', onnx.TensorProto.FLOAT, [None])

# The node building the condition. The first one
# sum over all axes.
rsum = make_node('ReduceSum', ['X'], ['rsum'])
# The second compares the result to 0.
cond = make_node('Greater', ['rsum', 'zero'], ['cond'])

# Builds the graph is the condition is True.
# Input for then
then_out = make_tensor_value_info(
    'then_out', onnx.TensorProto.FLOAT, None)
# The constant to return.
then_cst = from_array(numpy.array([1]).astype(numpy.float32))

# The only node.
then_const_node = make_node(
    'Constant', inputs=[],
    outputs=['then_out'],
    value=then_cst, name='cst1')

# And the graph wrapping these elements.
then_body = make_graph(
    [then_const_node], 'then_body', [], [then_out])

# Same process for the else branch.
else_out = make_tensor_value_info(
    'else_out', onnx.TensorProto.FLOAT, [5])
else_cst = from_array(numpy.array([-1]).astype(numpy.float32))

else_const_node = make_node(
    'Constant', inputs=[],
    outputs=['else_out'],
    value=else_cst, name='cst2')

else_body = make_graph(
    [else_const_node], 'else_body',
    [], [else_out])

# Finally the node If taking both graphs as attributes.
if_node = onnx.helper.make_node(
    'If', ['cond'], ['Y'],
    then_branch=then_body,
    else_branch=else_body)

# The final graph.
graph = make_graph([rsum, cond, if_node], 'if', [X], [Y], [zero])
onnx_model = make_model(graph)
check_model(onnx_model)

# Let's freeze the opset.
del onnx_model.opset_import[:]
opset = onnx_model.opset_import.add()
opset.domain = ''
opset.version = 15
onnx_model.ir_version = 8

# Save.
with open("onnx_if_sign.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

# Let's see the output.
sess = InferenceSession(onnx_model.SerializeToString(),
                        providers=["CPUExecutionProvider"])

x = numpy.ones((3, 2), dtype=numpy.float32)
res = sess.run(None, {'X': x})

# It works.
print("result", res)
print()

# Some display.
print(onnx_model)