# 배경 제거 Application

Colab 환경에서 배경 제거 애플리케이션을 만들어봅시다. 애플리케이션 사용자의 유스케이스는 아래와 같습니다.

- 사용자는 이미지 파일을 업로드할 수 있다.
- 사용자는 이미지에서 원하는 객체를 클릭한다.
- 사용자는 배경 제거 이미지의 결과를 확인하고 다운로드 받을 수 있다.

## 패키지 및 예제 데이터 다운로드하기
Python 패키지들을 설치하고, 예제로 사용할 이미지들도 다운로드합니다. Colab에서 실행하지 않는 경우 이 셀은 실행하지 않습니다.

In [2]:
!wget https://raw.githubusercontent.com/mentor1023/dl_apps/main/segmentation/requirements-colab.txt
!pip install -r requirements-colab.txt

!mkdir examples
!cd examples && wget https://github.com/mentor1023/dl_apps/raw/main/segmentation/examples/dog.jpg
!cd examples && wget https://github.com/mentor1023/dl_apps/raw/main/segmentation/examples/mannequin.jpg

--2024-12-02 12:32:55--  https://raw.githubusercontent.com/mentor1023/dl_apps/main/segmentation/requirements-colab.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 147 [text/plain]
Saving to: ‘requirements-colab.txt.1’


2024-12-02 12:32:55 (11.6 MB/s) - ‘requirements-colab.txt.1’ saved [147/147]

Collecting git+https://github.com/facebookresearch/segment-anything.git (from -r requirements-colab.txt (line 2))
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-zg99tbfa
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-zg99tbfa
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac0

## 패키지 불러오기

In [3]:
import os
import urllib
from typing import Tuple

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry

## 애플리케이션 UI 구현하기

In [4]:
# Layout only: this cell creates the components but attaches no event handlers.
with gr.Blocks() as app:
    gr.Markdown("# Interactive Remove Background from Image")
    # First row: two numeric fields labelled as the mouse x / y coordinates.
    with gr.Row():
        coord_x = gr.Number(label="Mouse coords x")
        coord_y = gr.Number(label="Mouse coords y")

    # Second row: the input image next to the output image (both height=600).
    with gr.Row():
        input_img = gr.Image(label="Input image", height=600)
        output_img = gr.Image(label="Output image", height=600)

In [5]:
# Start the app; with share=True the run reports a public gradio.live URL.
app.launch(inline=False, share=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://a6e35b165742fd33ec.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




In [6]:
# Shut the running server down before a later cell builds and launches a new app.
app.close()

Closing server running on port: 7860


## 마우스 클릭 이벤트 구현하기

In [7]:
def get_coords(evt: gr.SelectData):
    """Unpack the selected position from a Gradio select event.

    Args:
        evt: Select event object; only ``evt.index`` is read.

    Returns:
        The 2-tuple ``(evt.index[0], evt.index[1])``, which the app routes to
        the "Mouse coords x" and "Mouse coords y" fields.
    """
    position = evt.index
    return position[0], position[1]


# Same layout as before, now with a select handler on the input image.
with gr.Blocks() as app:
    gr.Markdown("# Interactive Remove Background from Image")
    # First row: two numeric fields that display the selected coordinates.
    with gr.Row():
        coord_x = gr.Number(label="Mouse coords x")
        coord_y = gr.Number(label="Mouse coords y")

    # Second row: the input image next to the output image (both height=600).
    with gr.Row():
        input_img = gr.Image(label="Input image", height=600)
        output_img = gr.Image(label="Output image", height=600)

    # No component inputs (None): get_coords only reads the event object, and
    # its two return values fill coord_x and coord_y.
    input_img.select(get_coords, None, [coord_x, coord_y])

In [8]:
# Start the app; with share=True the run reports a public gradio.live URL.
app.launch(inline=False, share=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://45ee2dd3d281c24500.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




In [9]:
# Shut the running server down before a later cell builds and launches a new app.
app.close()

Closing server running on port: 7860


## SAM 추론기 클래스 구현하기

In [10]:
# Local directory and file name used for the SAM checkpoint.
CHECKPOINT_PATH = "checkpoint"
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
# Remote location of the same checkpoint file.
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [11]:
class SAMInferencer:
    """Point-prompted Segment Anything (SAM) wrapper returning one mask per call.

    The constructor makes sure the checkpoint file exists locally (downloading
    it when missing), builds the model through ``sam_model_registry`` and wraps
    it in a ``SamPredictor``.
    """

    def __init__(
        self,
        checkpoint_path: str,
        checkpoint_name: str,
        checkpoint_url: str,
        model_type: str,
        device: torch.device,
    ):
        """Load the SAM weights and create the predictor.

        Args:
            checkpoint_path: Directory holding the checkpoint; created if missing.
            checkpoint_name: File name of the checkpoint inside that directory.
            checkpoint_url: URL the checkpoint is downloaded from when the file
                is not present locally.
            model_type: Key into ``sam_model_registry``.
            device: Device the model is moved to.
        """
        print("[INFO] Initailize inferencer")
        os.makedirs(checkpoint_path, exist_ok=True)
        checkpoint = os.path.join(checkpoint_path, checkpoint_name)
        if not os.path.exists(checkpoint):
            # ``import urllib`` alone does not guarantee that the ``request``
            # submodule is loaded, so import it explicitly where it is used.
            import urllib.request

            # Download under a temporary name and rename afterwards, so an
            # interrupted download never leaves a truncated file that the
            # existence check above would accept on the next run.
            partial = checkpoint + ".part"
            urllib.request.urlretrieve(checkpoint_url, partial)
            os.replace(partial, checkpoint)
        sam = sam_model_registry[model_type](checkpoint=checkpoint).to(device)
        self.predictor = SamPredictor(sam)

    def inference(
        self,
        image: np.ndarray,
        point_coords: np.ndarray,
        points_labels: np.ndarray,
    ) -> np.ndarray:
        """Segment ``image`` from point prompts and return the selected mask.

        Args:
            image: Image handed to ``SamPredictor.set_image``.
            point_coords: Prompt point coordinates, forwarded to
                ``SamPredictor.predict``.
            points_labels: One label per prompt point, forwarded likewise.

        Returns:
            The mask chosen by :meth:`select_mask`, with a trailing axis of
            length 1.
        """
        # NOTE(review): the image is registered with the predictor on every
        # call, even for repeated clicks on the same image.
        self.predictor.set_image(image)
        masks, scores, _ = self.predictor.predict(point_coords, points_labels)
        mask, _ = self.select_mask(masks, scores)
        return mask

    def select_mask(
        self, masks: np.ndarray, scores: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Pick the best-scoring candidate, avoiding the first one if possible.

        Args:
            masks: Candidate masks, indexed as ``masks[i, :, :]``.
            scores: One score per candidate mask.

        Returns:
            ``(mask, score)``: the chosen mask with a new trailing axis and its
            unmodified score as a length-1 array.

        Raises:
            ValueError: If ``scores`` is empty.
        """
        scores = np.asarray(scores)
        if scores.shape[0] == 0:
            raise ValueError("select_mask needs at least one candidate mask")

        # A large penalty on index 0 keeps the first candidate from winning
        # whenever another candidate exists. The penalty vector is sized from
        # the input so any number of candidates (including one) works.
        # Modelled on the score reweighting in SAM's ONNX utility:
        # Reference: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/utils/onnx.py#L92-L105
        score_reweight = np.array([-1000] + [0] * (scores.shape[0] - 1))
        score = scores + score_reweight
        best_idx = np.argmax(score)
        selected_mask = np.expand_dims(masks[best_idx, :, :], axis=-1)
        selected_score = np.expand_dims(scores[best_idx], axis=0)
        return selected_mask, selected_score

In [12]:
# Single shared inferencer used by the event handlers defined later.
inferencer = SAMInferencer(
    checkpoint_path=CHECKPOINT_PATH,
    checkpoint_name=CHECKPOINT_NAME,
    checkpoint_url=CHECKPOINT_URL,
    model_type="vit_h",
    device=DEVICE,
)

[INFO] Initailize inferencer


  state_dict = torch.load(f)


## 추론 및 배경 제거 후처리 구현하기

In [13]:
def extract_object(image: np.ndarray, point_x: int, point_y: int):
    """Cut out the object at a prompt point and make everything else transparent.

    Args:
        image: Input image array; its first three channels become the colour
            channels of the result.
        point_x: x coordinate of the prompt point.
        point_y: y coordinate of the prompt point.

    Returns:
        A four-channel array: the masked colour channels plus an alpha channel
        that is 255 inside the predicted mask and 0 outside it.
    """
    # Single prompt point with label 1 (foreground in SAM's prompt convention).
    point_coords = np.array([[point_x, point_y]])
    point_label = np.array([1])

    # Get mask
    mask = inferencer.inference(image, point_coords, point_label)

    # Binarise to a 0/1 uint8 mask, the form cv2 expects for ``mask=``.
    mask = (mask > 0).astype(np.uint8)

    # Zero out every pixel outside the mask.
    result_image = cv2.bitwise_and(image, image, mask=mask)

    # Channel order is left exactly as received; no colour conversion happens here.
    color_channels = result_image[..., :3]

    # Take alpha from the mask itself. Deriving it from pixel values (e.g.
    # "first channel == 0") would also make object pixels transparent
    # wherever that channel happens to be 0.
    alpha_channel = mask.reshape(mask.shape[:2]) * np.uint8(255)

    return np.dstack((color_channels, alpha_channel))


def extract_object_by_event(image: np.ndarray, evt: gr.SelectData):
    """Run ``extract_object`` at the position carried by a select event.

    Args:
        image: Image array forwarded unchanged to ``extract_object``.
        evt: Select event object; ``evt.index`` must unpack into two values.

    Returns:
        Whatever ``extract_object`` returns for that image and position.
    """
    x, y = evt.index
    result = extract_object(image, x, y)
    return result


def get_coords(evt: gr.SelectData):
    """Unpack the selected position from a Gradio select event.

    Args:
        evt: Select event object; only ``evt.index`` is read.

    Returns:
        The 2-tuple ``(evt.index[0], evt.index[1])``.
    """
    position = evt.index
    return position[0], position[1]

In [14]:
# Same layout, now with background removal attached to the input image.
with gr.Blocks() as app:
    gr.Markdown("# Interactive Remove Background from Image")
    # First row: two numeric fields that display the selected coordinates.
    with gr.Row():
        coord_x = gr.Number(label="Mouse coords x")
        coord_y = gr.Number(label="Mouse coords y")

    # Second row: the input image next to the output image (both height=600).
    with gr.Row():
        input_img = gr.Image(label="Input image", height=600)
        output_img = gr.Image(label="Output image", height=600)

    # Two handlers on the same select event: the first receives the input
    # image and writes the result to output_img, the second fills the
    # coordinate fields.
    input_img.select(extract_object_by_event, [input_img], [output_img])
    input_img.select(get_coords, None, [coord_x, coord_y])

In [15]:
# Start the app; with share=True the run reports a public gradio.live URL.
app.launch(inline=False, share=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://435090d43e046ea831.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




In [16]:
# Shut the running server down before a later cell builds and launches a new app.
app.close()

Closing server running on port: 7860


## 최종 App 구현하기

In [19]:
# Implement inferencer
# Local directory and file name used for the SAM checkpoint.
CHECKPOINT_PATH = "checkpoint"
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
# Remote location of the same checkpoint file.
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
# Key looked up in sam_model_registry when the model is built.
MODEL_TYPE = "default"
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class SAMInferencer:
    """Point-prompted Segment Anything (SAM) wrapper returning one mask per call.

    The constructor makes sure the checkpoint file exists locally (downloading
    it when missing), builds the model through ``sam_model_registry`` and wraps
    it in a ``SamPredictor``.
    """

    def __init__(
        self,
        checkpoint_path: str,
        checkpoint_name: str,
        checkpoint_url: str,
        model_type: str,
        device: torch.device,
    ):
        """Load the SAM weights and create the predictor.

        Args:
            checkpoint_path: Directory holding the checkpoint; created if missing.
            checkpoint_name: File name of the checkpoint inside that directory.
            checkpoint_url: URL the checkpoint is downloaded from when the file
                is not present locally.
            model_type: Key into ``sam_model_registry``.
            device: Device the model is moved to.
        """
        print("[INFO] Initailize inferencer")
        os.makedirs(checkpoint_path, exist_ok=True)
        checkpoint = os.path.join(checkpoint_path, checkpoint_name)
        if not os.path.exists(checkpoint):
            # ``import urllib`` alone does not guarantee that the ``request``
            # submodule is loaded, so import it explicitly where it is used.
            import urllib.request

            # Download under a temporary name and rename afterwards, so an
            # interrupted download never leaves a truncated file that the
            # existence check above would accept on the next run.
            partial = checkpoint + ".part"
            urllib.request.urlretrieve(checkpoint_url, partial)
            os.replace(partial, checkpoint)
        sam = sam_model_registry[model_type](checkpoint=checkpoint).to(device)
        self.predictor = SamPredictor(sam)

    def inference(
        self,
        image: np.ndarray,
        point_coords: np.ndarray,
        points_labels: np.ndarray,
    ) -> np.ndarray:
        """Segment ``image`` from point prompts and return the selected mask.

        Args:
            image: Image handed to ``SamPredictor.set_image``.
            point_coords: Prompt point coordinates, forwarded to
                ``SamPredictor.predict``.
            points_labels: One label per prompt point, forwarded likewise.

        Returns:
            The mask chosen by :meth:`select_mask`, with a trailing axis of
            length 1.
        """
        # NOTE(review): the image is registered with the predictor on every
        # call, even for repeated clicks on the same image.
        self.predictor.set_image(image)
        masks, scores, _ = self.predictor.predict(point_coords, points_labels)
        mask, _ = self.select_mask(masks, scores)
        return mask

    def select_mask(
        self, masks: np.ndarray, scores: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Pick the best-scoring candidate, avoiding the first one if possible.

        Args:
            masks: Candidate masks, indexed as ``masks[i, :, :]``.
            scores: One score per candidate mask.

        Returns:
            ``(mask, score)``: the chosen mask with a new trailing axis and its
            unmodified score as a length-1 array.

        Raises:
            ValueError: If ``scores`` is empty.
        """
        scores = np.asarray(scores)
        if scores.shape[0] == 0:
            raise ValueError("select_mask needs at least one candidate mask")

        # A large penalty on index 0 keeps the first candidate from winning
        # whenever another candidate exists. The penalty vector is sized from
        # the input so any number of candidates (including one) works.
        # Modelled on the score reweighting in SAM's ONNX utility:
        # Reference: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/utils/onnx.py#L92-L105
        score_reweight = np.array([-1000] + [0] * (scores.shape[0] - 1))
        score = scores + score_reweight
        best_idx = np.argmax(score)
        selected_mask = np.expand_dims(masks[best_idx, :, :], axis=-1)
        selected_score = np.expand_dims(scores[best_idx], axis=0)
        return selected_mask, selected_score


# Single shared inferencer used by the event handlers below.
inferencer = SAMInferencer(
    checkpoint_path=CHECKPOINT_PATH,
    checkpoint_name=CHECKPOINT_NAME,
    checkpoint_url=CHECKPOINT_URL,
    model_type=MODEL_TYPE,
    device=DEVICE,
)

# Implement event function
def extract_object(image: np.ndarray, point_x: int, point_y: int):
    """Cut out the object at a prompt point and make everything else transparent.

    Args:
        image: Input image array; its first three channels become the colour
            channels of the result.
        point_x: x coordinate of the prompt point.
        point_y: y coordinate of the prompt point.

    Returns:
        A four-channel array: the masked colour channels plus an alpha channel
        that is 255 inside the predicted mask and 0 outside it.
    """
    # The clicked point carries label 1 (foreground in SAM's prompt
    # convention). The extra (0, 0) point with label -1 presumably mirrors the
    # padding point used by SAM's ONNX export; TODO confirm.
    point_coords = np.array([[point_x, point_y], [0, 0]])
    point_label = np.array([1, -1])

    # Get mask
    mask = inferencer.inference(image, point_coords, point_label)

    # Binarise to a 0/1 uint8 mask, the form cv2 expects for ``mask=``.
    mask = (mask > 0).astype(np.uint8)

    # Zero out every pixel outside the mask.
    result_image = cv2.bitwise_and(image, image, mask=mask)

    # Channel order is left exactly as received; no colour conversion happens here.
    color_channels = result_image[..., :3]

    # Take alpha from the mask itself. Deriving it from pixel values (e.g.
    # "first channel == 0") would also make object pixels transparent
    # wherever that channel happens to be 0.
    alpha_channel = mask.reshape(mask.shape[:2]) * np.uint8(255)

    return np.dstack((color_channels, alpha_channel))


def extract_object_by_event(image: np.ndarray, evt: gr.SelectData):
    """Run ``extract_object`` at the position carried by a select event.

    Args:
        image: Image array forwarded unchanged to ``extract_object``.
        evt: Select event object; ``evt.index`` must unpack into two values.

    Returns:
        Whatever ``extract_object`` returns for that image and position.
    """
    x, y = evt.index
    result = extract_object(image, x, y)
    return result


def get_coords(evt: gr.SelectData):
    """Unpack the selected position from a Gradio select event.

    Args:
        evt: Select event object; only ``evt.index`` is read.

    Returns:
        The 2-tuple ``(evt.index[0], evt.index[1])``.
    """
    position = evt.index
    return position[0], position[1]


# Implement app
# Final app: layout, select handlers and a set of clickable examples.
with gr.Blocks() as app:
    gr.Markdown("# Interactive Remove Background from Image")
    # First row: two numeric fields that display the selected coordinates.
    with gr.Row():
        coord_x = gr.Number(label="Mouse coords x")
        coord_y = gr.Number(label="Mouse coords y")

    # Second row: the input image next to the output image (both height=600).
    with gr.Row():
        input_img = gr.Image(label="Input image", height=600)
        output_img = gr.Image(label="Output image", height=600)

    # Two handlers on the same select event: one fills the coordinate fields,
    # the other receives the input image and writes the result to output_img.
    input_img.select(get_coords, None, [coord_x, coord_y])
    input_img.select(extract_object_by_event, [input_img], output_img)

    gr.Markdown("## Image Examples")
    # Each example row supplies an image path plus x / y values, matching the
    # order of `inputs`. With fn and run_on_click=True set, picking an example
    # is expected to call extract_object on those values; confirm against the
    # installed Gradio version.
    gr.Examples(
        examples=[
            [os.path.join(os.getcwd(), "examples/dog.jpg"), 1013, 786],
            [os.path.join(os.getcwd(), "examples/mannequin.jpg"), 1720, 230],
        ],
        inputs=[input_img, coord_x, coord_y],
        outputs=output_img,
        fn=extract_object,
        run_on_click=True,
    )

# Start the app; with share=True the run reports a public gradio.live URL.
app.launch(inline=False, share=True)

[INFO] Initailize inferencer


  state_dict = torch.load(f)


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
IMPORTANT: You are using gradio version 3.40.0, however version 4.44.1 is available, please upgrade.
--------
Running on public URL: https://98c5867d2d7f9c8596.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




In [20]:
# Shut the running server down once the demo is finished.
app.close()

Closing server running on port: 7860
