In [2]:
!pip install -q mediapipe

In [1]:
# `import urllib` alone does not load the `request` submodule; import it
# explicitly so `urllib.request.urlretrieve` is guaranteed to resolve.
import urllib.request

# Demo image(s) to segment; each name is both the remote object name and the
# local file name it is saved under.
IMAGE_FILENAMES = ['segmentation_input_rotation0.jpg']

# Download every demo image into the current working directory.
for name in IMAGE_FILENAMES:
  url = f'https://storage.googleapis.com/mediapipe-assets/{name}'
  urllib.request.urlretrieve(url, name)

In [None]:
import numpy as np
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2  # For displaying images

# Highlight palette: MASK_COLORS[i] is the (R, G, B) color used for category id i.
MASK_COLORS = [
    (255, 0, 0),      # 0: red
    (0, 255, 0),      # 1: green
    (0, 0, 255),      # 2: blue
    (255, 255, 0),    # 3: yellow
    (255, 0, 255),    # 4: magenta
    (0, 255, 255),    # 5: cyan
]

# Gray (R, G, B) fill that the output canvas starts from before any class
# color is blended on top of it.
BG_COLOR = (192, 192, 192)

# Configure the ImageSegmenter: point it at the TFLite model file and ask for
# a category mask in the result.
# NOTE(review): the model path is relative and no cell visible here downloads
# the model, so the file presumably has to exist already — confirm.
base_options = python.BaseOptions(
    model_asset_path='web_development/model/selfie_multiclass_256x256.tflite',
)
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    output_category_mask=True,
)

# Segment every demo image, tint each category with its palette color, and
# show the result in an OpenCV window. The segmenter is used as a context
# manager so it is cleaned up when the `with` block exits.
with vision.ImageSegmenter.create_from_options(options) as segmenter:

    # Loop through demo image(s)
    for image_file_name in IMAGE_FILENAMES:

        # Create the MediaPipe image file that will be segmented
        image = mp.Image.create_from_file(image_file_name)

        # Retrieve the per-pixel category ids for the segmented image
        segmentation_result = segmenter.segment(image)
        category_mask = segmentation_result.category_mask.numpy_view()

        # The source pixels are only used as a shape template for the canvas;
        # drop an alpha channel (if any) so the canvas has three channels.
        original_image = image.numpy_view()
        if original_image.shape[-1] == 4:  # If RGBA, convert to RGB
            original_image = cv2.cvtColor(original_image, cv2.COLOR_RGBA2RGB)

        # Start from a canvas filled with the background color
        output_image = np.zeros_like(original_image, dtype=np.uint8)
        output_image[:] = BG_COLOR

        # Tint each class with its corresponding color
        for class_id, color in enumerate(MASK_COLORS):
            mask_color = np.array(color, dtype=np.uint8)
            class_mask = category_mask == class_id

            # Overlay that is black everywhere except on this class's pixels
            overlay = np.zeros_like(output_image, dtype=np.uint8)
            overlay[class_mask] = mask_color

            # Add the overlay at half strength on top of the current canvas
            # (addWeighted saturates at 255 for uint8 input)
            output_image = cv2.addWeighted(output_image, 1.0, overlay, 0.5, 0)

        # Report which category ids are present in the mask
        print(f"Unique classes in {image_file_name}: {np.unique(category_mask)}")

        # The palette and canvas are RGB, but cv2.imshow interprets 3-channel
        # arrays as BGR. Convert for display only (output_image stays RGB) so
        # the on-screen colors match the names in MASK_COLORS.
        display_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)
        cv2.imshow(f"Segmented Image with Highlighted Classes - {image_file_name}", display_image)
        cv2.waitKey(0)  # Wait for a key press before showing the next image
        cv2.destroyAllWindows()

I0000 00:00:1735057377.216500  715468 gl_context.cc:357] GL version: 2.1 (2.1 Metal - 89.3), renderer: Apple M1
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
W0000 00:00:1735057377.249323  715763 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
  graph_config = self._runner.get_graph_config()


Unique classes in segmentation_input_rotation0.jpg: [0 1 2 3 4 5]
