In [2]:
from pathlib import Path

import numpy as np
import onnxruntime as ort
from PIL import Image

In [5]:
def test_onnx_model_on_single_image(model_path, image_path, color_map, device='cpu'):
    """Run an ONNX segmentation model on one image and return a colorized mask.

    Args:
        model_path: Path to the ``.onnx`` model file.
        image_path: Path to the input image. It is converted to 3-channel RGB.
        color_map: Mapping of class id -> ``(R, G, B)`` tuple. Pixels whose
            predicted class id is not a key stay black ``(0, 0, 0)``.
        device: ``'cpu'`` (default) or ``'cuda'`` / ``'gpu'``. Selects the
            onnxruntime execution providers. The CUDA list keeps the CPU
            provider as a fallback.

    Returns:
        A ``PIL.Image.Image`` (RGB, uint8) with the height and width of the
        model's predicted label map.
    """
    # onnxruntime-gpu >= 1.9 requires providers to be given explicitly.
    # This also makes the `device` argument actually take effect.
    if str(device).lower() in ('cuda', 'gpu'):
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    else:
        providers = ['CPUExecutionProvider']
    session = ort.InferenceSession(model_path, providers=providers)

    # The model is assumed to have a single image input and to return
    # class scores as its first output.
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # Use a context manager so the file handle PIL opens lazily is released.
    with Image.open(image_path) as img:
        pixels = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0

    # HWC -> CHW, then add a batch axis: (1, 3, H, W), float32 in [0, 1].
    # NOTE(review): no mean/std normalization is applied here; confirm the
    # exported model expects plain [0, 1] inputs.
    batch = np.ascontiguousarray(pixels.transpose(2, 0, 1))[np.newaxis, ...]

    scores = session.run([output_name], {input_name: batch})[0]

    # Assumes scores are laid out as (1, num_classes, H, W) -- TODO confirm
    # against the model. argmax over the class axis gives (1, H, W); drop the batch.
    predicted = np.argmax(scores, axis=1).squeeze(0)

    # Paint each class id with its color. Unmapped ids remain black.
    height, width = predicted.shape
    rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id, rgb_value in color_map.items():
        rgb_image[predicted == class_id] = rgb_value

    return Image.fromarray(rgb_image)

In [6]:
# Class id -> RGB color used to paint the predicted segmentation mask.
# Class 9 (Background) is black. That is the same color the colorizer leaves
# on any pixel whose class id is missing from this map.
class_to_color = {
    0: (255, 197, 25),  # Forklift
    1: (140, 255, 25),  # Rack
    2: (140, 25, 255),  # Crate
    3: (226, 255, 25),  # Floor
    4: (255, 111, 25),  # Railing
    5: (255, 25, 197),  # Pallet
    6: (54, 255, 25),   # Stillage
    7: (25, 255, 82),   # iwhub
    8: (25, 82, 255),   # Dolly
    9: (0, 0, 0)        # Background
}
# The original name reads backwards (keys are class ids, values are colors).
# It is kept as an alias so existing cells that use it keep working.
color_to_class = class_to_color

# Distinct colors keep the rendered mask unambiguous.
assert len(set(class_to_color.values())) == len(class_to_color), "duplicate class colors"

In [7]:
onnx_model_path = "onnx/deeplab_model.onnx"
image_path = "rgb_0201.png"

# Build the output location from the user's home directory instead of a
# hard-coded absolute Windows path. Create the folder so `save` does not fail
# with FileNotFoundError when it is missing.
output_path = Path.home() / "Desktop" / "test" / "test12.png"
output_path.parent.mkdir(parents=True, exist_ok=True)

output_image = test_onnx_model_on_single_image(onnx_model_path, image_path, color_to_class)
output_image.show()  # Opens the image in the system's default viewer
output_image.save(output_path)  # Save the colorized mask