# Segment Anything 추론 예제 코드
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1VWB53jTGn50oTTHu3mWkBnDGxpJj8CAA/view?usp=sharing) [![GitHub](https://badges.aleen42.com/src/github.svg)](https://github.com/kim-jeonghyun/advanced_detection_segmentation_model)


작성자 : 김정현(kimjeonghyun.jkim@gmail.com)  
작성일 : 2023.10.24  
참고 : Segment Anything 공식 레포(https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb)  

In [None]:
# Mount Google Drive at /content/drive so the notebook can reach files
# stored there (the working directory used below lives under this mount).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import requests
import zipfile
import matplotlib.pyplot as plt
import cv2
import os
import glob

## Step 0. Setup 확인

In [None]:
%pwd

'/content/drive/MyDrive/Colab Notebooks/Segment-Anything-A-Foundation-Model-for-Image-Segmentation'

In [None]:
cd /content/drive/MyDrive/Colab Notebooks/Segment-Anything-A-Foundation-Model-for-Image-Segmentation

/content/drive/MyDrive/Colab Notebooks/Segment-Anything-A-Foundation-Model-for-Image-Segmentation


In [None]:
!python --version

In [None]:
!cat /etc/issue

In [None]:
import torch, torchvision

# Print the installed PyTorch version and whether CUDA is usable.
print(torch.__version__, torch.cuda.is_available())

In [None]:
# NumPy is not imported by any of the earlier cells shown, so import it here;
# without this the cell raises NameError on a fresh kernel.
import numpy as np

np.__version__

## Step 1. Install Segment Anything

In [None]:
pip install git+https://github.com/facebookresearch/segment-anything.git

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-8xg1u__p
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-8xg1u__p
  Resolved https://github.com/facebookresearch/segment-anything.git to commit 6fdee8f2727f4506cfbbe553e23b895e27956588
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: segment-anything
  Building wheel for segment-anything (setup.py) ... done
  Created wheel for segment-anything: filename=segment_anything-1.0-py3-none-any.whl size=36586 sha256=c8eb0708357c4d8777b2fc2fa48abb108df52d48528bb2e671f0b466e73c9784
  Stored in directory: /tmp/pip-ephem-wheel-cache-yx93w1hx/wheels/10/cf/59/9ccb2f0a1bcc81d4fbd0e501680b5d088d690c6cfbc02dc99d
Successfully built segment-anything
Installing collected packages: segment-anything
Successfully 

## Step 2. Download sample Images

샘플 이미지 대신 직접 준비한 이미지를 `input` 폴더에 넣어 사용하면 됨 (그래서 아래의 다운로드/압축 해제 코드는 주석 처리해 둠)

In [None]:
# def download_file(url, save_name):
#     url = url
#     if not os.path.exists(save_name):
#         file = requests.get(url)
#         open(save_name, 'wb').write(file.content)

# download_file(
#     'https://www.dropbox.com/s/0etn81u50kfs2ah/input.zip?dl=1',
#     'input.zip'
# )

In [None]:
# # Unzip the data file
# def unzip(zip_file=None):
#     try:
#         with zipfile.ZipFile(zip_file) as z:
#             z.extractall("./")
#             print("Extracted all")
#     except:
#         print("Invalid file")

# unzip('input.zip')

Extracted all


## Step 3. Download the Models

현재 Facebook(Meta)에서 제공하는 사전 학습 가중치(pretrained weights)는 3가지임  
base(ViT-B) 약 350MB, large(ViT-L) 약 1.2GB, huge(ViT-H) 약 2.5GB이므로 용도에 맞게 선택하여 다운받을 것  
크기가 크므로 매번 다운받을 필요는 없고, 추론 시 모델 파일 경로만 올바르게 지정해 주면 됨


In [None]:
# !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -O sam_vit_h.pth

# !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -O sam_vit_b.pth

# !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth -O sam_vit_l.pth

--2023-04-25 13:16:12--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 52.84.162.20, 52.84.162.51, 52.84.162.119, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|52.84.162.20|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2564550879 (2.4G) [binary/octet-stream]
Saving to: ‘sam_vit_h.pth’


2023-04-25 13:16:24 (208 MB/s) - ‘sam_vit_h.pth’ saved [2564550879/2564550879]

--2023-04-25 13:16:25--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 52.84.162.119, 52.84.162.51, 52.84.162.103, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|52.84.162.119|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 375042383 (358M) [binary/octet-stream]
Saving to: ‘sam_vit_b.pth’


2023-04-25 13:16:27 (229 MB/s) - ‘sam_vit_b.pth’ saved [375042383/375042383]

-

## Step 4. Inference

In [None]:
%%writefile segment.py
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import torch

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Command-line interface: a single optional --input flag that holds the path
# of the image to segment.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--input',
    default='input/image_4.jpg',
    help='path of the image to segment'
)
args = parser.parse_args()

# Create the output directory. exist_ok=True makes this a no-op when the
# directory already exists and avoids the check-then-create race.
os.makedirs('outputs', exist_ok=True)

def show_anns(anns):
    """Draw every mask in ``anns`` as a coloured overlay on the current axes.

    Args:
        anns: Sequence of dicts. Each dict must provide ``'area'`` (used as
            the sort key) and ``'segmentation'`` (a 2-D array; presumably a
            boolean mask as produced by SamAutomaticMaskGenerator - confirm).

    Returns:
        None. Does nothing when ``anns`` is empty.
    """
    if len(anns) == 0:
        return
    ax = plt.gca()
    ax.set_autoscale_on(False)
    # Largest masks are drawn first so that smaller ones end up on top.
    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        mask = ann['segmentation']
        # Build the RGBA overlay once: a random flat colour in the RGB
        # channels and the mask scaled by 0.35 as the alpha channel.
        overlay = np.empty((mask.shape[0], mask.shape[1], 4))
        overlay[:, :, :3] = np.random.random(3)
        overlay[:, :, 3] = mask * 0.35
        ax.imshow(overlay)

# torch.cuda.is_available must be *called*: the bare function object is always
# truthy, which would select CUDA even on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth")
sam.to(device)
mask_generator = SamAutomaticMaskGenerator(sam)

image_path = args.input
# basename copes with '/' as well as the platform separator, unlike a manual
# split on os.path.sep.
image_name = os.path.basename(image_path)
image = cv2.imread(image_path)
if image is None:
    # cv2.imread reports an unreadable or missing file by returning None.
    raise FileNotFoundError(f"Could not read image: {image_path}")
# OpenCV loads images as BGR; convert to RGB before segmentation and display.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)

# Render the image with the mask overlays and save it under outputs/ using
# the same file name as the input.
plt.figure(figsize=(12, 9))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.savefig(os.path.join('outputs', image_name), bbox_inches='tight')

Writing segment.py


In [None]:
# !python segment.py --input input/image_1.jpg
# !python segment.py --input input/image_2.jpg
# !python segment.py --input input/image_3.jpg
# !python segment.py --input input/image_4.jpg

'input' 폴더에 있는 이미지들 한꺼번에 처리하기

In [None]:
!python segment.py --input input/B546_FMS_8_0_1_frame200.jpg

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import torch

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

In [None]:
%pwd

'/content/drive/MyDrive/Colab Notebooks/Segment-Anything-A-Foundation-Model-for-Image-Segmentation'

In [None]:
def show_anns(anns):
    """Draw every mask in ``anns`` as a coloured overlay on the current axes.

    Args:
        anns: Sequence of dicts. Each dict must provide ``'area'`` (used as
            the sort key) and ``'segmentation'`` (a 2-D array; presumably a
            boolean mask as produced by SamAutomaticMaskGenerator - confirm).

    Returns:
        None. Does nothing when ``anns`` is empty.
    """
    if len(anns) == 0:
        return
    ax = plt.gca()
    ax.set_autoscale_on(False)
    # Largest masks are drawn first so that smaller ones end up on top.
    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        mask = ann['segmentation']
        # Build the RGBA overlay once: a random flat colour in the RGB
        # channels and the mask scaled by 0.35 as the alpha channel.
        overlay = np.empty((mask.shape[0], mask.shape[1], 4))
        overlay[:, :, :3] = np.random.random(3)
        overlay[:, :, 3] = mask * 0.35
        ax.imshow(overlay)

# torch.cuda.is_available must be *called*: the bare function object is always
# truthy, which would select CUDA even on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# NOTE(review): segment.py and the wget commands in this notebook use the
# name "sam_vit_h.pth"; confirm that "sam_vit_h_4b8939.pth" exists in the
# working directory.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device)
mask_generator = SamAutomaticMaskGenerator(sam)
