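# Interactive segmentation

This notebook runs MediaPipe's `InteractiveSegmenter` on sample images, using a single keypoint as the region of interest, and visualizes the resulting category mask.

Reference: https://developers.google.com/mediapipe/solutions/vision/interactive_segmenter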

In [1]:
# One-time setup: uncomment to download the magic_touch model to
# ./models/model.tflite, the path the configuration cells below load it from.
# !wget -O ./models/model.tflite -q https://storage.googleapis.com/mediapipe-models/interactive_segmenter/magic_touch/float32/1/magic_touch.tflite

In [2]:
import cv2
import numpy as np
import mediapipe as mp

from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.components import containers

In [3]:
import math

def _normalized_to_pixel_coordinates(
    normalized_x: float, normalized_y: float, image_width: int,
    image_height: int):
    """Converts normalized value pair to pixel coordinates."""

    # Checks if the float value is between 0 and 1.
    def is_valid_normalized_value(value: float) -> bool:
        return (value > 0 or math.isclose(0, value)) and (value < 1 or
                                                        math.isclose(1, value))

    if not (is_valid_normalized_value(normalized_x) and
            is_valid_normalized_value(normalized_y)):
        # TODO: Draw coordinates even if it's outside of the image bounds.
        return None
    x_px = min(math.floor(normalized_x * image_width), image_width - 1)
    y_px = min(math.floor(normalized_y * image_height), image_height - 1)
    return x_px, y_px



In [4]:
# Colors used to paint the two-color mask visualisation.
BG_COLOR = (192, 192, 192) # gray
MASK_COLOR = (255, 255, 255) # white

# Normalized coordinates of the point of interest; used both as the
# segmenter's keypoint and to draw the marker on the output image.
x = 0.68
y = 0.68

# Short aliases for the MediaPipe types used to build the region of interest.
RegionOfInterest = vision.InteractiveSegmenterRegionOfInterest
NormalizedKeypoint = containers.keypoint.NormalizedKeypoint

model_path = './models/model.tflite'
# Create the options that will be used for InteractiveSegmenter.
# InteractiveSegmenterOptions is the options type that belongs to
# InteractiveSegmenter; the category mask is requested because the cells
# below read `segmentation_result.category_mask`.
base_options = python.BaseOptions(model_asset_path=model_path)
options = vision.InteractiveSegmenterOptions(base_options=base_options,
                                             output_category_mask=True)

# Images to run the segmenter on (paths relative to this notebook).
IMAGE_FILENAMES = ['./../../../../test_img/f3978ebc-9030-49e7-aa5e-4a370a955e1b.jpg',
                    './../../../../test_img/download.jpg']

In [6]:
# Segment every image around the keypoint (x, y) and show the result as a
# two-color mask (MASK_COLOR for the selected object, BG_COLOR elsewhere).
try:
    with vision.InteractiveSegmenter.create_from_options(options) as segmenter:
        # The region of interest depends only on (x, y), so build it once.
        roi = RegionOfInterest(format=RegionOfInterest.Format.KEYPOINT,
                               keypoint=NormalizedKeypoint(x, y))

        for image_file_name in IMAGE_FILENAMES:
            # Load the image.
            image = mp.Image.create_from_file(image_file_name)

            # Run the segmentation.
            segmentation_result = segmenter.segment(image, roi)
            category_mask = segmentation_result.category_mask

            # Build solid foreground / background images of the input's shape.
            image_data = image.numpy_view()
            fg_image = np.full(image_data.shape, MASK_COLOR, dtype=np.uint8)
            bg_image = np.full(image_data.shape, BG_COLOR, dtype=np.uint8)

            # Apply the mask: repeat it across 3 channels and pick the
            # foreground color wherever the mask value exceeds 0.1.
            condition = np.stack((category_mask.numpy_view(),) * 3, axis=-1) > 0.1
            output_image = np.where(condition, fg_image, bg_image)

            # Draw a white dot with black border to denote the point of interest.
            # cv2.circle takes (img, center, radius, color, thickness); a
            # thickness of cv2.FILLED (-1) draws a filled disc.
            dot_radius = 6
            keypoint_px = _normalized_to_pixel_coordinates(x, y, image.width, image.height)
            # The helper returns None for out-of-range coordinates; skip the
            # marker in that case instead of passing None to cv2.circle.
            if keypoint_px is not None:
                cv2.circle(output_image, keypoint_px, dot_radius + 5, (0, 0, 0), cv2.FILLED)
                cv2.circle(output_image, keypoint_px, dot_radius, (255, 255, 255), cv2.FILLED)

            print(f'Segmentation mask of {image_file_name}:')
            cv2.imshow(image_file_name, output_image)
            # Keep the window up for 3 seconds (or until a key is pressed).
            cv2.waitKey(3000)
finally:
    # Close the preview windows even if loading or segmentation failed.
    cv2.destroyAllWindows()

Segmentation mask of ./../../../../test_img/f3978ebc-9030-49e7-aa5e-4a370a955e1b.jpg:
Segmentation mask of ./../../../../test_img/download.jpg:


### 対象物以外のモザイクはsegmentationと同じ

## 対象を囲む

In [None]:
# NOTE: this cell repeats the configuration of the mask-visualisation section
# above (without the color constants).

# Normalized coordinates of the point of interest; used both as the
# segmenter's keypoint and to draw the marker on the output image.
x = 0.68
y = 0.68

# Short aliases for the MediaPipe types used to build the region of interest.
RegionOfInterest = vision.InteractiveSegmenterRegionOfInterest
NormalizedKeypoint = containers.keypoint.NormalizedKeypoint

model_path = './models/model.tflite'
# Create the options that will be used for InteractiveSegmenter.
# InteractiveSegmenterOptions is the options type that belongs to
# InteractiveSegmenter; the category mask is requested because the next cell
# reads `segmentation_result.category_mask`.
base_options = python.BaseOptions(model_asset_path=model_path)
options = vision.InteractiveSegmenterOptions(base_options=base_options,
                                             output_category_mask=True)

# Images to run the segmenter on (paths relative to this notebook).
IMAGE_FILENAMES = ['./../../../../test_img/f3978ebc-9030-49e7-aa5e-4a370a955e1b.jpg',
                    './../../../../test_img/download.jpg']

In [8]:
# Color blended over the selected object. It is mixed into the image after
# the channel swap below, so the tuple is interpreted in that (BGR) order.
OVERLAY_COLOR = (100, 100, 0)

# Segment every image around the keypoint (x, y) and show the original image
# with the selected object tinted by OVERLAY_COLOR.
try:
    with vision.InteractiveSegmenter.create_from_options(options) as segmenter:
        # The region of interest depends only on (x, y), so build it once.
        roi = RegionOfInterest(format=RegionOfInterest.Format.KEYPOINT,
                               keypoint=NormalizedKeypoint(x, y))

        for image_file_name in IMAGE_FILENAMES:
            image = mp.Image.create_from_file(image_file_name)

            segmentation_result = segmenter.segment(image, roi)
            category_mask = segmentation_result.category_mask

            # Swap the R and B channels for display with cv2.imshow, which
            # expects BGR. (COLOR_RGB2BGR and COLOR_BGR2RGB perform the same
            # swap; this name states the actual direction.)
            # NOTE(review): assumes the loaded image is 3-channel RGB - confirm
            # for inputs that carry an alpha channel.
            image_data = cv2.cvtColor(image.numpy_view(), cv2.COLOR_RGB2BGR)

            overlay_image = np.full(image_data.shape, OVERLAY_COLOR, dtype=np.uint8)

            # Per-pixel blend weight: the boolean mask is cast to float
            # (True -> 1.0, False -> 0.0) and scaled, giving 0.7 opacity on the
            # selected object and 0 everywhere else.
            selected = np.stack((category_mask.numpy_view(),) * 3, axis=-1) > 0.1
            alpha = selected.astype(float) * 0.7

            # Alpha-composite the overlay onto the image.
            output_image = image_data * (1 - alpha) + overlay_image * alpha
            output_image = output_image.astype(np.uint8)

            # Draw the point of interest as a white dot with a black border.
            # cv2.circle takes (img, center, radius, color, thickness); a
            # thickness of cv2.FILLED (-1) draws a filled disc.
            dot_radius = 6
            keypoint_px = _normalized_to_pixel_coordinates(x, y, image.width, image.height)
            # The helper returns None for out-of-range coordinates; skip the
            # marker in that case instead of passing None to cv2.circle.
            if keypoint_px is not None:
                cv2.circle(output_image, keypoint_px, dot_radius + 5, (0, 0, 0), cv2.FILLED)
                cv2.circle(output_image, keypoint_px, dot_radius, (255, 255, 255), cv2.FILLED)

            print(f'{image_file_name}:')
            cv2.imshow(image_file_name, output_image)
            # Keep the window up for 3 seconds (or until a key is pressed).
            cv2.waitKey(3000)
finally:
    # Close the preview windows even if loading or segmentation failed.
    cv2.destroyAllWindows()

./../../../../test_img/f3978ebc-9030-49e7-aa5e-4a370a955e1b.jpg:
./../../../../test_img/download.jpg:
