In [None]:
import matplotlib.pyplot as plt
import numpy as np
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision


In [None]:
# Colours (RGB) used to render the binary segmentation mask.
BG_COLOUR = (0, 0, 0)           # Black
MASK_COLOUR = (255, 255, 255)   # White

# Inputs: the TFLite segmentation model and the frame to segment.
MODEL_PATH = '../preprocessing/segmentation_models/MediaPipe/selfie_segmenter_landscape.tflite'
IMAGE_PATH = "../data/images/YiJumping/frame.001.png"

# Mask pixels with a value above this threshold are painted MASK_COLOUR,
# all others BG_COLOUR.
# NOTE(review): the category mask is expected to hold integer category
# indices, in which case this behaves like "> 0" — confirm against the model.
MASK_THRESHOLD = 0.2

# Create options for ImageSegmenter; only the per-pixel category mask is requested.
base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)

# Only the segmentation call runs inside the `with` block, so the segmenter
# is closed as soon as the mask has been retrieved.
with vision.ImageSegmenter.create_from_options(options=options) as segmenter:
    # Load the frame to be segmented as a MediaPipe image
    image = mp.Image.create_from_file(IMAGE_PATH)

    # Retrieve the category mask and copy it into an independent NumPy array
    # while the segmenter is still open.
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    mask_array = np.array(category_mask.numpy_view(), copy=True)

# Original frame as a NumPy array (not needed for the mask rendering below).
image_data = image.numpy_view()

# Size the solid colour images from the mask (H, W) with 3 colour channels,
# not from the input frame: a PNG with an alpha channel has 4 channels, which
# would not broadcast against the 3-channel colours / stacked condition.
mask_rgb_shape = mask_array.shape + (3,)
fg_image = np.full(mask_rgb_shape, MASK_COLOUR, dtype=np.uint8)
bg_image = np.full(mask_rgb_shape, BG_COLOUR, dtype=np.uint8)

# Replicate the single-channel mask across the 3 colour channels and pick
# foreground/background colour per pixel.
condition = np.stack((mask_array,) * 3, axis=-1) > MASK_THRESHOLD
output_image = np.where(condition, fg_image, bg_image)

fig, ax = plt.subplots()
ax.imshow(output_image)
ax.set_title("Segmentation mask")
ax.axis("off")
plt.show()
