Segment-Anything

Generate segmentation masks for an input image from an input prompt.
Save the masks as a NumPy (.npy) file in the specified output directory.


In [None]:
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install opencv-python pycocotools matplotlib onnxruntime onnx

In [None]:
import logging
import os
import sys
import re
import cv2
import numpy as np
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator

In [None]:
# Configuration: each setting below comes from the environment variable of the
# same name and may also be overridden by a name=value command-line argument.

# SAM backbone to build: "default", "vit_h", "vit_l" or "vit_b".
model_type = os.getenv("model_type", "vit_h")

# Path to the SAM weights; each model_type needs its own matching checkpoint.
checkpoint_path = os.getenv("checkpoint_path")

# Path of the image to segment.
input_image_path = os.getenv("input_image_path")

# Segmentation prompt (its format is interpreted by the prediction step).
input_prompt = os.getenv("input_prompt")

# Where results are written (temporary data storage for local execution).
data_dir = os.getenv('data_dir', '../../data/')

# Dummy output, disabled until C3 supports components with no outputs:
# output_dummy = os.environ.get('output_dummy', 'output_dummy')

In [None]:
# Command-line overrides: a ``name=value`` argument replaces the module-level
# variable ``name`` (e.g. ``input_prompt=500,375``) with the string ``value``.

_PARAMETER_NAME = re.compile(r'[A-Za-z_][A-Za-z0-9_]*')


def parse_cli_parameters(argv):
    """Return the ``(name, value)`` pairs found in ``name=value`` arguments.

    Only the first ``=`` separates the name from the value, so values may
    themselves contain ``=``, quotes or backslashes. Arguments whose name is
    not a plain ASCII identifier (e.g. ``--flag=x``, ``=x``, ``9x=1``) and
    arguments without ``=`` are ignored.

    Args:
        argv: Sequence of command-line argument strings, typically ``sys.argv``.

    Returns:
        List of ``(name, value)`` string tuples, in argument order.
    """
    pairs = []
    for arg in argv:
        name, sep, value = arg.partition('=')
        if sep and _PARAMETER_NAME.fullmatch(name):
            pairs.append((name, value))
    return pairs


parameters = parse_cli_parameters(sys.argv)

# Private loop names so a CLI parameter can never clobber the loop variables.
for _param_name, _param_value in parameters:
    logging.warning('Parameter: %s="%s"', _param_name, _param_value)
    # Bind the value directly instead of exec()-ing a generated assignment:
    # exec() crashed on values containing '=' or quotes, interpreted
    # backslashes as escapes, and let a crafted argument run arbitrary code.
    globals()[_param_name] = _param_value

In [None]:
# Get masks from the configured prompt and save them as a .npy file.


def parse_point_prompt(prompt):
    """Parse a point prompt string into SAM ``point_coords`` / ``point_labels``.

    NOTE(review): the prompt format is not specified anywhere in this notebook;
    this assumes pixel coordinates "x1 y1 x2 y2 ..." separated by whitespace
    and/or commas, with every point treated as foreground (label 1).

    Args:
        prompt: The raw prompt string.

    Returns:
        ``(coords, labels)``: a float32 array of shape (N, 2) and an integer
        array of N ones.

    Raises:
        ValueError: if the prompt is missing/empty, has an odd number of
            coordinates, or contains a non-numeric token.
    """
    if prompt is None or not prompt.strip():
        raise ValueError("input_prompt must contain at least one 'x y' point")
    tokens = prompt.replace(",", " ").split()
    if len(tokens) % 2:
        raise ValueError(
            f"input_prompt needs an even number of coordinates, got {len(tokens)}: {prompt!r}"
        )
    coords = np.array([float(t) for t in tokens], dtype=np.float32).reshape(-1, 2)
    labels = np.ones(len(coords), dtype=np.int64)
    return coords, labels


def resolve_output_path(data_dir, image_path):
    """Return the .npy path the masks are written to.

    If ``data_dir`` names a directory (trailing separator or an existing
    directory), the file ``<image stem>_masks.npy`` inside it is used.
    Otherwise ``data_dir`` is treated as the output file path itself.
    """
    _, filename = os.path.split(data_dir)
    if not filename or os.path.isdir(data_dir):
        stem = os.path.splitext(os.path.basename(image_path))[0]
        return os.path.join(data_dir, stem + "_masks.npy")
    return data_dir


if __name__ == "__main__":
    import torch  # segment_anything already depends on torch

    if not checkpoint_path:
        # Without a checkpoint, sam_model_registry builds an untrained model.
        raise ValueError("checkpoint_path must point to a SAM checkpoint matching model_type")
    if not input_image_path:
        raise ValueError("input_image_path is required")

    # Use the GPU when present instead of assuming CUDA is always available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
    sam.to(device=device)
    predictor = SamPredictor(sam)

    input_image = cv2.imread(input_image_path)
    if input_image is None:
        raise FileNotFoundError(f"could not read image: {input_image_path}")
    # OpenCV decodes to BGR; SamPredictor.set_image expects RGB by default.
    input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
    predictor.set_image(input_image)

    point_coords, point_labels = parse_point_prompt(input_prompt)
    masks, _, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels)

    # Save the result into the given directory (or file path).
    output_path = resolve_output_path(data_dir, input_image_path)
    output_directory = os.path.dirname(output_path)
    if output_directory:  # os.makedirs('') would raise for a bare filename
        os.makedirs(output_directory, exist_ok=True)
    np.save(output_path, masks)



