# Segmentation

Reference: [MediaPipe Image Segmenter guide](https://developers.google.com/mediapipe/solutions/vision/image_segmenter)

In [1]:
# !wget -O ./models/deeplabv3.tflite -q https://storage.googleapis.com/mediapipe-models/image_segmenter/deeplab_v3/float32/1/deeplab_v3.tflite

In [1]:
import mediapipe as mp
import numpy as np
import math
import cv2

In [2]:
# Location of the TFLite segmentation model loaded by every segmenter below.
model_path = "./models/deeplabv3.tflite"

## Image

In [8]:
# Input images processed by the single-image demos in this section.
IMAGE_FILENAMES = [
    "./../../../../test_img/f3978ebc-9030-49e7-aa5e-4a370a955e1b.jpg",
    "./../../../../test_img/download.jpg",
]

In [5]:
# Short aliases for the MediaPipe Tasks classes used in this notebook.
BaseOptions = mp.tasks.BaseOptions
VisionRunningMode = mp.tasks.vision.RunningMode
ImageSegmenter = mp.tasks.vision.ImageSegmenter
ImageSegmenterOptions = mp.tasks.vision.ImageSegmenterOptions

# Segmenter configuration for still images; the category mask output is
# enabled because the cells below read `category_mask` from the result.
options = ImageSegmenterOptions(
    running_mode=VisionRunningMode.IMAGE,
    output_category_mask=True,
    base_options=BaseOptions(model_asset_path=model_path),
)

In [6]:
DESIRED_HEIGHT = 480
DESIRED_WIDTH = 480

def resize_and_show(image, file_name):
    """Resize `image` to fit the preview size and show it in an OpenCV window.

    The longer side is scaled to DESIRED_WIDTH / DESIRED_HEIGHT and the other
    side is scaled by the same factor, so the aspect ratio is preserved.
    The window stays up for at most 3 seconds (or until a key is pressed).

    Args:
        image: Array-like image whose `shape` starts with (height, width).
        file_name: Name used in the window title.
    """
    h, w = image.shape[:2]
    if h < w:
        # Landscape: pin the width, derive the height.
        scaled_h = math.floor(h/(w/DESIRED_WIDTH))
        # Very elongated images round down to 0, which cv2.resize rejects.
        target_size = (DESIRED_WIDTH, max(1, scaled_h))
    else:
        # Portrait or square: pin the height, derive the width.
        scaled_w = math.floor(w/(h/DESIRED_HEIGHT))
        target_size = (max(1, scaled_w), DESIRED_HEIGHT)
    # cv2.resize takes the target size as (width, height).
    img = cv2.resize(image, target_size)
    cv2.imshow(f'Segmentation mask of {file_name}:', img)
    cv2.waitKey(3000)

In [12]:
# Default rendering: paint masked pixels white and everything else gray.
BG_COLOR = (192, 192, 192) # gray
MASK_COLOR = (255, 255, 255) # white

with ImageSegmenter.create_from_options(options) as segmenter:
    for file_name in IMAGE_FILENAMES:
        mp_image = mp.Image.create_from_file(file_name)
        segmented_masks = segmenter.segment(mp_image)
        category_mask = segmented_masks.category_mask

        # Solid-color canvases with the same shape as the input image.
        image_data = mp_image.numpy_view()
        fg_image = np.full(image_data.shape, MASK_COLOR, dtype=np.uint8)
        bg_image = np.full(image_data.shape, BG_COLOR, dtype=np.uint8)

        # Boolean mask, repeated over 3 channels: True where the category
        # mask value is above 0.5.
        # NOTE(review): presumably category 0 is background, so this selects
        # every non-background pixel - confirm against the model's labels.
        mask_2d = category_mask.numpy_view()
        condition = np.stack([mask_2d, mask_2d, mask_2d], axis=-1) > 0.5
        # True -> fg_image, False -> bg_image
        output_image = np.where(condition, fg_image, bg_image)

        print(f'Segmentation mask of {file_name}:')
        resize_and_show(output_image, file_name)

cv2.destroyAllWindows()

(1200, 1200)
Segmentation mask of ./../../../../test_img/f3978ebc-9030-49e7-aa5e-4a370a955e1b.jpg:
(189, 267)
Segmentation mask of ./../../../../test_img/download.jpg:


In [13]:
# Application: keep the segmented subject sharp and blur everything else.

with ImageSegmenter.create_from_options(options) as segmenter:
    for file_name in IMAGE_FILENAMES:
        mp_image = mp.Image.create_from_file(file_name)
        segmented_masks = segmenter.segment(mp_image)
        category_mask = segmented_masks.category_mask

        # Pixels used where the mask is True. The channel order is swapped
        # for OpenCV display; COLOR_RGB2BGR is the same conversion code as
        # COLOR_BGR2RGB, named for the direction actually taken here.
        image_data = cv2.cvtColor(mp_image.numpy_view(), cv2.COLOR_RGB2BGR)
        # Pixels used where the mask is False (background).
        blurred_image = cv2.GaussianBlur(image_data, (55,55), 0)
        # Build the 3-channel boolean mask.
        mask_2d = category_mask.numpy_view()
        condition = np.stack([mask_2d, mask_2d, mask_2d], axis=-1) > 0.1
        output_image = np.where(condition, image_data, blurred_image)

        print(f'Blurred background of {file_name}:')
        resize_and_show(output_image, file_name)
cv2.destroyAllWindows()

Blurred background of ./../../../../test_img/f3978ebc-9030-49e7-aa5e-4a370a955e1b.jpg:
Blurred background of ./../../../../test_img/download.jpg:


## Video

In [9]:
# Aliases are repeated here so the Video section can be run without the
# Image section's setup cell.
BaseOptions = mp.tasks.BaseOptions
VisionRunningMode = mp.tasks.vision.RunningMode
ImageSegmenter = mp.tasks.vision.ImageSegmenter
ImageSegmenterOptions = mp.tasks.vision.ImageSegmenterOptions

# Same model as before, switched to VIDEO running mode.
# Note: this rebinds the `options` name used by the image cells above.
options = ImageSegmenterOptions(
    running_mode=VisionRunningMode.VIDEO,
    output_category_mask=True,
    base_options=BaseOptions(model_asset_path=model_path),
)


In [11]:
# Open the input video and fail loudly if it cannot be read; otherwise the
# processing loop would end immediately and leave an empty output file.
INPUT_VIDEO_PATH = './../../../../test_video/2795691-hd_1920_1080_25fps.mp4'
cap = cv2.VideoCapture(INPUT_VIDEO_PATH)
if not cap.isOpened():
    raise OSError(f'Could not open input video: {INPUT_VIDEO_PATH}')

# Mirror the source frame rate and frame size in the output file.
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
output_vid = cv2.VideoWriter('./../segmentation_output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

In [13]:
# Segment every frame of the input video, blur the background, and write the
# result to the output file.
with ImageSegmenter.create_from_options(options) as segmenter:
    frame_index = 1
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                # No more frames (or the read failed): stop processing.
                break

            # OpenCV delivers frames in BGR order, while the image is declared
            # as SRGB; convert so the model sees the correct channel order.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame)
            # Timestamp of this frame in milliseconds, derived from the frame rate.
            frame_timestamp_ms = int(1000 * frame_index / fps)
            segmented_masks = segmenter.segment_for_video(mp_image, frame_timestamp_ms)
            category_mask = segmented_masks.category_mask

            # Pixels used where the mask is False (background).
            blurred_image = cv2.GaussianBlur(frame, (55,55), 0)
            # Build the 3-channel boolean mask.
            condition = np.stack((category_mask.numpy_view(),) * 3, axis=-1) > 0.1
            # Combine: original pixels where the mask is True, blurred elsewhere.
            # The BGR frame is used here because VideoWriter expects BGR.
            output_image = np.where(condition, frame, blurred_image)
            # Append the frame to the output video.
            output_vid.write(output_image)
            frame_index += 1
    finally:
        # Always finalize the output file and free the input, even on error.
        output_vid.release()
        cap.release()

## Camera
image推論を連続で使用

In [15]:
# Aliases are repeated here so the camera section can be run on its own.
BaseOptions = mp.tasks.BaseOptions
VisionRunningMode = mp.tasks.vision.RunningMode
ImageSegmenter = mp.tasks.vision.ImageSegmenter
ImageSegmenterOptions = mp.tasks.vision.ImageSegmenterOptions

# Configure an image segmenter in IMAGE mode; each camera frame is
# segmented as an independent still image.
options = ImageSegmenterOptions(
    running_mode=VisionRunningMode.IMAGE,
    output_category_mask=True,
    base_options=BaseOptions(model_asset_path=model_path),
)

In [17]:
# Live preview: segment each webcam frame and blur the background.
# Press 'q' in the preview window to stop.
with ImageSegmenter.create_from_options(options) as segmenter:
    cap = cv2.VideoCapture(0)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                # Camera unavailable or the stream ended.
                break

            # OpenCV delivers frames in BGR order, while the image is declared
            # as SRGB; convert so the model sees the correct channel order.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame)
            segmented_masks = segmenter.segment(mp_image)
            category_mask = segmented_masks.category_mask

            # Pixels used where the mask is False (background).
            blurred_image = cv2.GaussianBlur(frame, (55,55), 0)
            # Build the 3-channel boolean mask.
            condition = np.stack((category_mask.numpy_view(),) * 3, axis=-1) > 0.1
            # Compose on the BGR frame because cv2.imshow expects BGR.
            output_image = np.where(condition, frame, blurred_image)

            cv2.imshow('camera' , output_image)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Free the camera and close the preview window even if an error occurred.
        cap.release()
        cv2.destroyAllWindows()