In [None]:
import cv2
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import numpy as np

In [None]:
# Load the input image from disk and derive an RGB copy for MediaPipe.
image_path = '../data/images/greenspillexample.jpg'
image = cv2.imread(image_path)
# cv2.imread does not raise on a missing or undecodable file; it returns None.
# Fail here with a clear message instead of letting cvtColor raise an opaque
# assertion error on the next line.
if image is None:
    raise FileNotFoundError(f"Could not read image (missing or unreadable): {image_path}")
# `image` stays in OpenCV's BGR order (it is reused when writing the output);
# `image_rgb` is the channel-swapped copy handed to the segmenter.
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


In [None]:
# Configure the image segmenter: point it at the TFLite model on disk and
# request a single category mask instead of per-class confidence masks.
model_path = "../models/selfie_multiclass_256x256.tflite"
base_options = python.BaseOptions(model_asset_path=model_path)
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    output_category_mask=True,
    output_confidence_masks=False,
)


In [None]:
# Wrap the RGB pixel data in a MediaPipe image (does not need the segmenter).
input_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_rgb)

# Run segmentation; the context manager closes the segmenter when done.
with vision.ImageSegmenter.create_from_options(options=options) as segmenter:
    segmentation_result = segmenter.segment(input_image)
    # Category mask as a NumPy array of per-pixel class indices.
    # NOTE(review): depending on the MediaPipe version the mask may come back
    # at the model's 256x256 resolution or at the input resolution - confirm.
    # The later resize step copes with either.
    category_mask = segmentation_result.category_mask.numpy_view()


In [None]:
# Scale the mask to the original image's dimensions. cv2.resize takes the
# target size as (width, height), i.e. the reverse of NumPy's (rows, cols).
# Nearest-neighbour interpolation keeps every pixel a valid class index
# instead of blending neighbouring labels.
orig_height, orig_width = image.shape[:2]
category_mask_resized = cv2.resize(
    category_mask,
    (orig_width, orig_height),
    interpolation=cv2.INTER_NEAREST,
)


In [None]:
# Category indices to keep in the output.
# NOTE(review): per the MediaPipe selfie multiclass model card, 2 = body-skin
# and 3 = face-skin - confirm against the model version in use.
target_classes = [2, 3]
# Boolean mask: True wherever the pixel's class is one of the targets.
mask = np.isin(element=category_mask_resized, test_elements=target_classes)


In [None]:
# Apply the mask to the original (BGR) image: selected pixels keep their
# colour, everything else is set to 0 (black).
mask_3ch = np.stack([mask]*3, axis=-1)  # repeat the 2-D mask across 3 channels
output_image = np.where(mask_3ch, image, 0)

# Save the colour-segmented result
output_path = '../outputs/segmented_face_skin.png'
# cv2.imwrite reports failure (e.g. the output directory does not exist) by
# returning False rather than raising, so check it explicitly.
if not cv2.imwrite(output_path, output_image):
    raise OSError(f"Failed to write image: {output_path}")
