Segment-Anything

Select objects specified by input points, a box, or both, using the Segment Anything library.
https://github.com/facebookresearch/segment-anything
Outputs NumPy data files (masks, scores and logits) to the specified output directory.


In [1]:
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install opencv-python
!pip install torch
!pip install torchvision

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to c:\users\colin\appdata\local\temp\pip-req-build-nlkrrkpc
  Resolved https://github.com/facebookresearch/segment-anything.git to commit 6fdee8f2727f4506cfbbe553e23b895e27956588
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'


  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git 'C:\Users\colin\AppData\Local\Temp\pip-req-build-nlkrrkpc'




In [2]:
import logging
import os
import re
import sys

import cv2
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator

In [3]:
# Configuration cell: every value is read from an environment variable and can
# afterwards be overridden by a `name=value` command-line argument.

# model type: default, vit_h, vit_l or vit_b
model_type = os.environ.get("model_type", "vit_h")

# checkpoint
# each model type requires its own matching checkpoint file
checkpoint_path = os.environ.get("checkpoint_path")

# input image
input_image_path = os.environ.get("input_image_path")

# input points
# provide input points, a box, or both to specify the selected object
# format:
# x,y,x,y,x,y... for points. Multiple allowed
# The environment variable is looked up under the variable's own name first;
# the historical name "input_array" is still honoured as a fallback so that
# existing deployments keep working.
input_points = os.environ.get(
    "input_points", os.environ.get("input_array", "None")
)  # e.g. 100,200

# input box
# format:
# x,y,x,y for box. Only one box allowed
input_box = os.environ.get("input_box", "None")  # e.g. 100,200,300,400

# temporary data storage for local execution
data_dir = os.environ.get('data_dir', '../../data/')

# dummy_output (to be fixed once C3 supports < 1 outputs)
# output_dummy = os.environ.get('output_dummy', 'output_dummy')

In [4]:
def parse_cli_parameters(argv):
    """Extract ``name=value`` overrides from a list of command-line arguments.

    An argument is accepted when it consists of an ASCII identifier
    (letters, digits and underscores, not starting with a digit), followed by
    ``=`` and an arbitrary (possibly empty) value. Everything after the first
    ``=`` belongs to the value, so values may themselves contain ``=``,
    quotes or backslashes.

    Args:
        argv: iterable of argument strings, e.g. ``sys.argv``.

    Returns:
        A list of ``(name, value)`` string tuples in the order encountered.
        Arguments that do not have the ``name=value`` shape are ignored.
    """
    # DOTALL so that a value containing a newline is kept intact.
    pattern = re.compile(r'([A-Za-z_][A-Za-z0-9_]*)=(.*)', re.DOTALL)
    overrides = []
    for argument in argv:
        found = pattern.fullmatch(argument)
        if found:
            overrides.append((found.group(1), found.group(2)))
    return overrides


_overrides = parse_cli_parameters(sys.argv)

# Kept as a list of `name="value"` strings, the same shape as before, for
# logging and for any later cell that inspects it.
parameters = ['{}="{}"'.format(name, value) for name, value in _overrides]

for parameter, (name, value) in zip(parameters, _overrides):
    logging.warning('Parameter: ' + parameter)
    # Assign the raw string directly instead of exec()-ing generated source:
    # exec would interpret backslashes in the value as escape sequences (a
    # Windows path such as C:\Users\... is a syntax error) and would execute
    # arbitrary code smuggled in through a quote character.
    # NOTE(review): any global name can still be overridden from the command
    # line; restrict to the known parameter names if that is not intended.
    globals()[name] = value

In [5]:
# Get masks from a given prompt (points, a box, or both) and save the results.

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam.to(device=device)
predictor = SamPredictor(sam)

# cv2.imread signals an unreadable or missing file by returning None.
input_image = cv2.imread(input_image_path)
if input_image is None:
    raise FileNotFoundError(f"Could not read input image: {input_image_path!r}")
# OpenCV loads images in BGR channel order, while set_image assumes RGB unless
# told otherwise.
predictor.set_image(input_image, image_format="BGR")

has_points = input_points != "None"
has_box = input_box != "None"
if not (has_points or has_box):
    raise ValueError("Provide input_points, input_box, or both to select an object.")

# Convert the comma-separated prompt strings into arrays.
points = None
input_label = None
box = None
if has_points:
    coordinates = [int(value) for value in input_points.split(',')]
    if len(coordinates) % 2 != 0:
        raise ValueError(
            f"input_points must contain an even number of values (x,y pairs), "
            f"got {len(coordinates)}."
        )
    points = np.array(coordinates).reshape(-1, 2)
    # Label 1 for every point, i.e. all points are passed with the same label.
    input_label = np.array([1] * len(points))
if has_box:
    box = np.array([int(value) for value in input_box.split(',')])
    if box.shape != (4,):
        raise ValueError(
            f"input_box must contain exactly four values (x,y,x,y), got {box.size}."
        )

# Multiple candidate masks are requested only for a points-only prompt; as
# soon as a box is supplied a single mask is requested.
masks, scores, logits = predictor.predict(
    point_coords=points,
    point_labels=input_label,
    box=box[None, :] if has_box else None,
    multimask_output=not has_box,
)

# Save the results into the given directory. The directory is created as
# given and the paths are joined, so data_dir works with or without a
# trailing path separator.
if data_dir:
    os.makedirs(data_dir, exist_ok=True)
np.save(os.path.join(data_dir, "get_masks"), masks)
np.save(os.path.join(data_dir, "scores"), scores)
np.save(os.path.join(data_dir, "logits"), logits)






both
done
