# In [5]:
import onnxruntime as ort
import numpy as np
import cv2
from PIL import Image

def preprocess_image(image, input_size=(1280, 720)):
    """Convert a BGR frame into a batched, channel-first float32 array.

    Args:
        image: Frame in OpenCV's BGR channel order, laid out as
            height x width x channels (it is fed to ``cv2.resize`` and
            ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)``).
        input_size: Target ``(width, height)`` handed to ``cv2.resize``.
            Defaults to 1280x720, the size this script's model expects.

    Returns:
        ``np.float32`` array with a leading batch axis of length 1, followed
        by channels, height and width. Values are divided by 255, so 8-bit
        input ends up in ``[0, 1]``.
    """
    resized = cv2.resize(image, input_size)
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    # HWC -> CHW, then scale intensities; dividing a float32 array by a
    # Python float keeps the float32 dtype.
    chw = rgb.transpose(2, 0, 1).astype(np.float32) / 255.0
    return np.expand_dims(chw, axis=0)  # Add batch dimension

def postprocess_output(output, color_map):
    """Turn raw segmentation scores into a color-coded RGB image.

    Args:
        output: Model scores with a batch axis of length 1 first and the
            class axis second, i.e. ``(1, num_classes, H, W)``.
        color_map: Mapping from integer class index to an ``(R, G, B)``
            value. Classes without an entry are rendered black; entries
            whose index lies outside the class axis are ignored.

    Returns:
        ``np.uint8`` array of shape ``(H, W, 3)`` holding the color of the
        highest-scoring class at each pixel.

    Raises:
        ValueError: If the batch axis is not of length 1 (from ``squeeze``).
    """
    scores = np.asarray(output)
    predicted = np.argmax(scores, axis=1).squeeze(0)

    # argmax over axis 1 only yields indices in [0, num_classes), so a table
    # of that size covers every label; rows left at zero render as black.
    num_classes = scores.shape[1]
    palette = np.zeros((num_classes, 3), dtype=np.uint8)
    for class_id, rgb_value in color_map.items():
        if 0 <= class_id < num_classes:
            palette[class_id] = rgb_value

    # A single gather colors the whole frame, instead of building a boolean
    # mask over every pixel once per class.
    return palette[predicted]

def run_live_inference(model_path, color_map, source=1):
    """Run an ONNX segmentation model on a video stream and display the result.

    Each frame is read from ``source``, preprocessed, passed through the
    model, color-coded with ``color_map`` and shown in a window titled
    'Live Segmentation'. The loop ends when a frame cannot be read or when
    'q' is pressed.

    Args:
        model_path: Path to the ONNX model file.
        color_map: Mapping from class index to an RGB color, forwarded to
            ``postprocess_output``.
        source: Argument for ``cv2.VideoCapture``: a camera index (0 is
            usually the default webcam) or a path to a video file.
            Defaults to 1, the index this script originally hard-coded.
    """
    # Load the ONNX model
    session = ort.InferenceSession(model_path)

    # Only the first input and first output of the model are used.
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # Start video capture
    cap = cv2.VideoCapture(source)

    # try/finally so the capture device and windows are released even when
    # preprocessing, inference or display raises (or on Ctrl-C).
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            image = preprocess_image(frame)
            outputs = session.run([output_name], {input_name: image})
            segmented_image = postprocess_output(outputs[0], color_map)

            # imshow expects BGR, while the color map is expressed in RGB.
            segmented_image = cv2.cvtColor(segmented_image, cv2.COLOR_RGB2BGR)
            cv2.imshow('Live Segmentation', segmented_image)

            # Press 'q' to quit the video stream
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()

# Example usage
# Palette listed in class-id order. Despite its name, the resulting dict maps
# class id -> (R, G, B) color, which is how postprocess_output consumes it.
color_to_class = dict(enumerate([
    (255, 197, 25),  # 0: Forklift
    (140, 255, 25),  # 1: Rack
    (140, 25, 255),  # 2: Crate
    (226, 255, 25),  # 3: Floor
    (255, 111, 25),  # 4: Railing
    (255, 25, 197),  # 5: Pallet
    (54, 255, 25),   # 6: Stillage
    (25, 255, 82),   # 7: iwhub
    (25, 82, 255),   # 8: Dolly
    (0, 0, 0),       # 9: Background
]))

onnx_model_path = "onnx/deeplab_model.onnx"

# Starts the live segmentation loop as soon as this file is executed.
run_live_inference(onnx_model_path, color_to_class)