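# Segment Anything Image Segmentator

Upload a photo and get back a colour-coded segmentation map produced by Meta's Segment Anything Model (SAM, ViT-H checkpoint), ready to use as the input for ControlNet Segment Anything. Run the three cells in order: install requirements, initialize the model, run the segmentator.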

In [None]:
#@title Install requirements {display-mode: "form"}
!pip install -q gradio git+https://github.com/facebookresearch/segment-anything.git
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [2]:
#@title Initialize the model {display-mode: "form"}
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from PIL import Image
import numpy as np
import torch
import gc

# Build the ViT-H variant of SAM from the checkpoint fetched by the install
# cell, then wrap it in the automatic "segment everything" mask generator.
# NOTE(review): presumably the weights load onto the CPU here; segment_image()
# moves the model to the GPU only while it generates masks — confirm.
SAM_MODEL_TYPE = "vit_h"
SAM_CHECKPOINT = "./sam_vit_h_4b8939.pth"

sam_builder = sam_model_registry[SAM_MODEL_TYPE]
sam = sam_builder(checkpoint=SAM_CHECKPOINT)
mask_generator = SamAutomaticMaskGenerator(sam)

def show_anns(anns, seed=None):
    """Render SAM masks as a single colour-coded RGB segmentation map.

    Each mask is painted in its own random solid colour. Masks are drawn from
    largest to smallest area, so smaller regions stay visible on top of the
    larger ones that overlap them.

    Args:
        anns: list of annotation dicts (as returned by
            ``SamAutomaticMaskGenerator.generate``); each needs a 2-D
            ``'segmentation'`` mask array and a numeric ``'area'``.
        seed: optional seed for the colour generator. ``None`` (default)
            gives different colours on every call; pass an int to get a
            reproducible map.

    Returns:
        An RGB ``PIL.Image.Image`` with the masks' height/width (pixels not
        covered by any mask are black), or ``None`` if ``anns`` is empty.
    """
    if not anns:
        return None

    sorted_anns = sorted(anns, key=lambda ann: ann["area"], reverse=True)
    h, w = sorted_anns[0]["segmentation"].shape
    canvas = np.zeros((h, w, 3), dtype=np.uint8)

    # One RGB colour per mask. `high` is exclusive, so 256 makes every channel
    # value 0..255 possible (randint(255) could never produce 255).
    rng = np.random.default_rng(seed)
    colors = rng.integers(0, 256, size=(len(sorted_anns), 3), dtype=np.uint8)

    for ann, color in zip(sorted_anns, colors):
        # Coerce to bool so this is mask selection, not integer fancy-indexing.
        mask = np.asarray(ann["segmentation"]).astype(bool, copy=False)
        # Paint in place; replaces a full-frame temp image + PIL paste per mask.
        canvas[mask] = color

    # uint8 H x W x 3 is inferred as RGB; avoids the deprecated `mode=` argument.
    return Image.fromarray(canvas)

def segment_image(image, device=None):
    """Generate a colour-coded SAM segmentation map for one image.

    The model is moved to ``device`` only for the duration of mask generation
    and is always moved back to the CPU afterwards — even if generation
    raises — so GPU memory is released between requests.

    Args:
        image: the uploaded picture. NOTE(review): presumably an
            H x W x 3 ``uint8`` numpy array, which is what ``gr.Image``
            passes by default — confirm if other callers are added.
        device: torch device to run SAM on. Defaults to ``"cuda"`` when CUDA
            is available, otherwise ``"cpu"``.

    Returns:
        The RGB ``PIL.Image`` built by ``show_anns``, or ``None`` when no
        image was supplied or SAM produced no masks.
    """
    if image is None:
        # Presumably what gr.Image passes when nothing was uploaded.
        return None
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Generate masks on the accelerator only for as long as needed.
    sam.to(device)
    try:
        masks = mask_generator.generate(image)
    finally:
        # Offload even when generate() fails (e.g. CUDA OOM) so the weights
        # are not left occupying GPU memory for the next request.
        sam.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Build the colour map, then drop the per-mask arrays before returning.
    seg_map = show_anns(masks)
    del masks
    gc.collect()
    return seg_map

In [None]:
#@title Run the Segmentator {display-mode: "form"}
import gradio as gr

demo = gr.Interface(
    fn=segment_image,
    title="🙌 Segment Anything",
    # The Interface's run button is labelled "Submit", so the instructions
    # must name that button (they previously said "Send").
    description = """⬆️ Upload your reference photo here, then click on Submit. \n \n ⬇️ Download the generated Segmentation Map and use it as the Input for ControlNet Segment Anything""",
    # NOTE(review): `shape=` is a Gradio 3.x argument (removed in 4.x);
    # presumably it resizes the upload to 512x512 before segment_image sees it.
    inputs = gr.Image(shape=(512, 512), label = "Input Image"),
    outputs = gr.Image(shape=(512, 512), label="Segmentation Map"),
)

# debug=True keeps this cell running so errors and logs appear in the notebook.
demo.launch(debug=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Note: opening Chrome Inspector may crash demo inside Colab notebooks.

To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>