# segment-anything 

In [None]:
# !pip install git+https://github.com/facebookresearch/segment-anything.git

导入库，加载模型 (Import libraries and load the SAM model)

In [1]:
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from segment_anything import sam_model_registry, SamPredictor
# Checkpoint location can be overridden with the SAM_CHECKPOINT environment
# variable; the default is the original hard-coded path, so existing setups
# keep working unchanged.
sam_checkpoint = os.environ.get("SAM_CHECKPOINT", "/home/zjt/model/sam_vit_h_4b8939.pth")
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
# Use the GPU when one is present; otherwise fall back to the CPU instead of
# crashing on `.to("cuda")`.
sam_device = "cuda" if torch.cuda.is_available() else "cpu"
sam.to(sam_device)
predictor = SamPredictor(sam)

In [2]:

# Draw a marker on the image for every point clicked so far.
def on_click_image(image, global_state, evt: gr.SelectData):
    """Record the clicked position as a prompt point and redraw all markers.

    Args:
        image: PIL image currently shown in the image component.
        global_state: Per-session dict with the keys 'img', 'points' and
            'mask_input'. It is updated in place: the first image seen is
            stored under 'img' and the click position is appended to 'points'.
        evt: Gradio select event; ``evt.index`` is used as the clicked (x, y).

    Returns:
        Tuple ``(image_with_markers, global_state)``. The image is an RGB copy
        of ``image`` with a filled red circle at every recorded point.
    """
    if image is None:
        # Nothing to draw on; leave the state untouched.
        return image, global_state

    radius_scale = 0.01
    p_color = (255, 0, 0)

    # Remember the first image we see; run_predict reads it from the state.
    if global_state['img'] is None:
        global_state['img'] = image
    global_state['points'].append(evt.index)

    overlay_rgba = Image.new("RGBA", image.size, 0)
    overlay_draw = ImageDraw.Draw(overlay_rgba)
    # Radius scales with the image width; clamp to at least 1 px so markers do
    # not collapse to a single pixel on images narrower than 100 px.
    rad_draw = max(1, int(image.size[0] * radius_scale))
    for p in global_state['points']:
        cx, cy = int(p[0]), int(p[1])
        overlay_draw.ellipse(
            (cx - rad_draw, cy - rad_draw, cx + rad_draw, cy + rad_draw),
            fill=p_color,
        )
    image_draw = Image.alpha_composite(image.convert("RGBA"), overlay_rgba).convert("RGB")
    return image_draw, global_state

# Run SAM on the collected points and blend the predicted mask over the image.
def run_predict(img, global_state):
    """Predict a mask from the recorded points and overlay it on ``img``.

    Args:
        img: PIL image currently shown in the image component (the one the
            mask is blended onto).
        global_state: Per-session dict with the keys 'img' (image passed to
            the predictor), 'points' (list of clicked (x, y) positions) and
            'mask_input' (logits kept from the first prediction, or None).

    Returns:
        Tuple ``(blended_image, global_state)``. When there is no image or no
        point has been recorded yet, ``img`` and the state are returned
        unchanged instead of raising inside the predictor.
    """
    points = global_state['points']
    if img is None or global_state['img'] is None or len(points) == 0:
        # Button pressed before any click: nothing to segment.
        return img, global_state

    # Defensive: hand the predictor a 3-channel array whatever the image mode.
    image = np.array(global_state['img'].convert('RGB'))
    predictor.set_image(image)
    input_point = np.array(points)
    # Every recorded point is passed with label 1.
    input_label = np.array([1] * len(points))

    if global_state['mask_input'] is None:
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        # Show the same candidate whose logits are stored for later calls,
        # rather than always showing index 0.
        best = int(np.argmax(scores))
        global_state['mask_input'] = logits[best, :, :]
        mask = masks[best]
    else:
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            mask_input=global_state['mask_input'][None, :, :],
            multimask_output=False,
        )
        mask = masks[0]

    im_mask = Image.fromarray(np.uint8(mask * 255))
    # NOTE(review): looks like a debugging artifact (writes into the current
    # working directory); kept so existing behaviour is unchanged.
    im_mask.save('tmp.png')
    # Image.blend needs both images in the same mode.
    image_draw = Image.blend(img.convert('RGB'), im_mask.convert('RGB'), 0.5)
    return image_draw, global_state
    
   
def _reset_state():
    """Return a fresh session state: no cached image, no points, no mask prior."""
    return {'img': None, 'points': [], 'mask_input': None}


# Example image shown on start-up; skipped when the file is not present on
# this machine so the demo can still launch.
_example_image = '/home/zjt/workspace/img2img/AnyDoor/examples/Gradio/FG/10036.jpg'

with gr.Blocks() as demo:
    global_state = gr.State(_reset_state())
    im = gr.Image(
        type='pil',
        value=_example_image if os.path.isfile(_example_image) else None,
    )

    # Clicking the image records a prompt point and redraws the markers.
    im.select(
        on_click_image,
        inputs=[im, global_state],
        outputs=[im, global_state],
        queue=False,
    )
    # A newly uploaded or cleared image invalidates the cached image, points
    # and mask prior; without this the next prediction would still use the
    # previous image's state.
    im.upload(_reset_state, inputs=None, outputs=global_state, queue=False)
    im.clear(_reset_state, inputs=None, outputs=global_state, queue=False)

    btn = gr.Button()
    btn.click(fn=run_predict, inputs=[im, global_state], outputs=[im, global_state])
demo.launch()

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


