# Object masks from prompts with SAM

The Segment Anything Model (SAM) predicts object masks given prompts that indicate the desired object. The model first converts the image into an image embedding, from which high-quality masks can then be produced efficiently for each prompt.

The `SamPredictor` class provides a simple interface for prompting the model. The user first sets an image with the `set_image` method, which computes the image embedding once. Prompts can then be passed to the `predict` method to efficiently predict masks from them. The model accepts point and box prompts as input, as well as masks from a previous prediction.
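Reusing a previous mask can be sketched with NumPy alone. This minimal sketch assumes the output shapes of `SamPredictor.predict` (masks `(C, H, W)`, scores `(C,)`, low-resolution logits `(C, 256, 256)`). The helper `pick_best_and_refine_input` is hypothetical and not part of `segment_anything`. It keeps the highest-scoring mask and reshapes its logits into the `(1, 256, 256)` array expected by the `mask_input` argument on the next call.

```python
import numpy as np


def pick_best_and_refine_input(masks, scores, logits):
    """Hypothetical helper: select the highest-scoring mask and build mask_input.

    Assumes masks has shape (C, H, W), scores (C,), logits (C, 256, 256),
    as returned by SamPredictor.predict.
    """
    masks = np.asarray(masks)
    scores = np.asarray(scores)
    logits = np.asarray(logits)
    if not (len(masks) == len(scores) == len(logits)):
        raise ValueError("masks, scores and logits must have the same length")
    best = int(np.argmax(scores))
    # Add a leading axis: mask_input is a single (1, 256, 256) low-resolution mask.
    mask_input = logits[best][None, :, :]
    return masks[best], float(scores[best]), mask_input


# Dummy data standing in for a multimask prediction (C = 3).
masks = np.zeros((3, 4, 5), dtype=bool)
masks[1, 1:3, 1:4] = True
scores = np.array([0.70, 0.95, 0.81])
logits = np.random.randn(3, 256, 256).astype(np.float32)

best_mask, best_score, mask_input = pick_best_and_refine_input(masks, scores, logits)
print(best_score, mask_input.shape)  # 0.95 (1, 256, 256)
```

In the notebook, `mask_input` would then be passed as `predictor.predict(point_coords=..., point_labels=..., mask_input=mask_input, multimask_output=False)`.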

## Environment Set-up

If running locally with Jupyter, first install `segment_anything` in your environment using the [installation instructions](https://github.com/facebookresearch/segment-anything#installation) in the repository. If running from Google Colab, run the cell below, which installs the dependencies and downloads the ViT-B checkpoint. In Colab, be sure to select 'GPU' under 'Edit' -> 'Notebook Settings' -> 'Hardware accelerator'.

In [10]:
import sys
import torch
import torchvision
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())

# Install the image/plotting dependencies and segment-anything itself.
!{sys.executable} -m pip install opencv-python matplotlib
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

# Download a checkpoint: ViT-B is the smallest model; uncomment the next line for ViT-H.
#!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

PyTorch version: 2.0.0+cu118
Torchvision version: 0.15.1+cu118
CUDA is available: True
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-0pumvf89
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-0pumvf89
  Resolved https://github.com/facebookresearch/segment-anything.git to commit 3518c86b78b3bc9cf4fbe3d18e682fad1c79dc51
  Preparing metadata (setup.py) ... done
--2023-04-13 19:37:11--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 13.35.8.29, 13.35.8.51, 13.35.8.35, ...
Connecting to dl.fbaipublicfiles.com (
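The cell above downloads only the ViT-B checkpoint, while later cells hard-code both `sam_checkpoint` and `model_type`, which must match. A small lookup keeps the two in sync. This sketch assumes the checkpoint filenames listed in the segment-anything README; `checkpoint_for` is a hypothetical helper.

```python
# Checkpoint filenames as listed in the segment-anything README (assumed current).
SAM_CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}
BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything/"


def checkpoint_for(model_type):
    """Hypothetical helper: return (filename, url) for a SAM backbone, failing early on typos."""
    try:
        filename = SAM_CHECKPOINTS[model_type]
    except KeyError:
        raise ValueError(
            f"unknown model_type {model_type!r}; expected one of {sorted(SAM_CHECKPOINTS)}"
        ) from None
    return filename, BASE_URL + filename


sam_checkpoint, url = checkpoint_for("vit_b")
print(sam_checkpoint)  # sam_vit_b_01ec64.pth
```

Deriving `sam_checkpoint` from `model_type` this way means switching backbones only requires changing one string.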

## Set-up

Necessary imports and helper functions for displaying points, boxes, and masks.

In [11]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

In [12]:
def show_mask(mask, ax, random_color=False):
    # Overlay a boolean (H, W) mask as a semi-transparent RGBA layer (alpha 0.6).
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    # Foreground points (label 1) as green stars, background points (label 0) as red stars.
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    # Draw an [x0, y0, x1, y1] box as an unfilled green rectangle.
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
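`show_mask` draws the mask as a separate RGBA layer at 0.6 opacity on top of the image. When the overlay is needed as an array (for example to save it with OpenCV rather than matplotlib), the same compositing can be computed directly. This is a sketch; `blend_mask` is a hypothetical helper that reuses the default color and opacity of `show_mask`.

```python
import numpy as np


def blend_mask(image, mask, color=(30, 144, 255), alpha=0.6):
    """Hypothetical helper: alpha-blend a boolean mask onto an RGB uint8 image.

    Pixels inside the mask become (1 - alpha) * pixel + alpha * color;
    pixels outside are unchanged, like the transparent part of show_mask's layer.
    """
    image = np.asarray(image)
    mask = np.asarray(mask, dtype=bool)
    if image.shape[:2] != mask.shape:
        raise ValueError("mask must have the same height and width as the image")
    out = image.astype(np.float32)
    color = np.asarray(color, dtype=np.float32)
    out[mask] = (1.0 - alpha) * out[mask] + alpha * color
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


image = np.zeros((2, 2, 3), dtype=np.uint8)
mask = np.array([[True, False], [False, False]])
print(blend_mask(image, mask)[0, 0].tolist())  # [18, 86, 153]
```

OpenCV's `cv2.imwrite` expects BGR channel order, so the result would need `cv2.cvtColor(result, cv2.COLOR_RGB2BGR)` before saving.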


# Loading the model

In [13]:
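import sys
sys.path.append("..")  # only needed when running from the notebooks/ folder of a segment-anything clone
import torch
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "sam_vit_b_01ec64.pth"  # must match model_type
model_type = "vit_b"

# Fall back to the CPU (much slower) if no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)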
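`SamPredictor.predict` takes points as an `(N, 2)` array of `(x, y)` pixel coordinates, labels as an `(N,)` array where 1 marks foreground and 0 background, and an optional box as `[x0, y0, x1, y1]`. Shape mistakes tend to surface deep inside the model, so a helper like the hypothetical `build_prompt` below (a sketch, not part of `segment_anything`) can check them up front and return keyword arguments for `predictor.predict`.

```python
import numpy as np


def build_prompt(points=None, labels=None, box=None, multimask_output=True):
    """Hypothetical helper: validate prompts and return kwargs for SamPredictor.predict."""
    kwargs = {"multimask_output": multimask_output}
    if points is not None:
        points = np.asarray(points, dtype=np.float32).reshape(-1, 2)
        if labels is None:
            # Assume every point marks the object (foreground label 1).
            labels = np.ones(len(points), dtype=np.int64)
        labels = np.asarray(labels, dtype=np.int64).reshape(-1)
        if len(labels) != len(points):
            raise ValueError("need one label per point")
        if not np.isin(labels, (0, 1)).all():
            raise ValueError("labels must be 1 (foreground) or 0 (background)")
        kwargs["point_coords"] = points
        kwargs["point_labels"] = labels
    if box is not None:
        box = np.asarray(box, dtype=np.float32).reshape(4)
        if box[2] <= box[0] or box[3] <= box[1]:
            raise ValueError("box must be [x0, y0, x1, y1] with x1 > x0 and y1 > y0")
        kwargs["box"] = box
    if len(kwargs) == 1:
        raise ValueError("provide at least a point or a box prompt")
    return kwargs


prompt = build_prompt(points=[[500, 375]], box=[425, 600, 700, 875], multimask_output=False)
print(prompt["point_coords"].shape, prompt["point_labels"])  # (1, 2) [1]
# masks, scores, logits = predictor.predict(**prompt)
```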

# Loading the data

**TEST USING THE IMAGES ALREADY ANNOTATED BY YOLO**

In [14]:
!rm -rf lepidoptera
!git clone https://github.com/lucien92/lepidoptera/

Cloning into 'lepidoptera'...
remote: Enumerating objects: 248, done.
remote: Counting objects: 100% (248/248), done.
remote: Compressing objects: 100% (243/243), done.
remote: Total 248 (delta 5), reused 248 (delta 5), pack-reused 0
Receiving objects: 100% (248/248), 13.27 MiB | 18.56 MiB/s, done.
Resolving deltas: 100% (5/5), done.


In [16]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os

from segment_anything import build_sam, SamPredictor, sam_model_registry

In [15]:
#####  parameters #####
# CSV generated by YOLO (Amegilla quadrifasciata for now, to be replaced with the Lepidoptera data)
csv_path = "/content/lepidoptera/segment_anything/result_2023-04-05 13:46:11.874005"

sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"

device = "cuda" if torch.cuda.is_available() else "cpu"

output_path = "/content/lepidoptera/segment_anything/output"
#####  parameters #####
# -p: no error if the folder already exists (e.g. when the cell is re-run)
!mkdir -p "/content/lepidoptera/segment_anything/output"
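`read_csv` treats columns 1 to 4 as pixel coordinates `x0, y0, x1, y1`. If the YOLO export instead stored boxes in the usual YOLO label convention (normalized center x, center y, width, height), they would have to be converted first. The sketch below assumes that convention, which the CSV used here may or may not follow; `yolo_to_xyxy` is a hypothetical helper.

```python
import numpy as np


def yolo_to_xyxy(box, img_width, img_height):
    """Hypothetical helper: convert a normalized YOLO box (cx, cy, w, h) to pixel [x0, y0, x1, y1]."""
    cx, cy, w, h = (float(v) for v in box)
    x0 = (cx - w / 2) * img_width
    y0 = (cy - h / 2) * img_height
    x1 = (cx + w / 2) * img_width
    y1 = (cy + h / 2) * img_height
    # Clip to the image in case the detection extends past the border.
    return np.clip([x0, y0, x1, y1], 0, [img_width, img_height, img_width, img_height])


print(yolo_to_xyxy([0.5, 0.5, 0.25, 0.5], img_width=640, img_height=480).tolist())
# [240.0, 120.0, 400.0, 360.0]
```

The image size would come from the loaded image, e.g. `h, w = image.shape[:2]`.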

In [17]:
#####  util functions #####

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))

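#####  predict #####

from functools import lru_cache

@lru_cache(maxsize=1)
def load_predictor(sam_checkpoint, model_type, device):
    # Build SAM once and reuse it: reloading the checkpoint for every image is slow.
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    return SamPredictor(sam)

def predict(img_path, sam_checkpoint, model_type, device, output_path, input_point): #, box

    image_name = os.path.splitext(os.path.basename(img_path))[0]
    image = cv2.imread(img_path)
    if image is None:
        raise FileNotFoundError(f"could not read image: {img_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR, SAM expects RGB

    predictor = load_predictor(sam_checkpoint, model_type, device)
    predictor.set_image(image)  # compute the image embedding

    # A single foreground point (label 1) at the centre of the YOLO box.
    input_label = np.array([1])

    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
        # box=box
    )

    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(input_point, input_label, plt.gca())
        # show_box(box, plt.gca())
        plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('on')
        plt.savefig(os.path.join(output_path, image_name + ".png"))
        plt.close()  # release the figure; otherwise every image stays in memory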

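def read_csv(csv_path):
    # Each row is assumed to be: image_path,x0,y0,x1,y1 (pixel coordinates).
    with open(csv_path, "r") as f:

        img_paths = []
        img_bbox = []
        img_bbox_centers = []

        for line in f:
            if not line.strip():
                continue  # skip blank lines (e.g. a trailing newline)
            line = line.rstrip("\n").split(",")

            img_path = line[0]
            img_paths.append(img_path)

            bbox = np.array([float(line[1]), float(line[2]), float(line[3]), float(line[4])])
            img_bbox.append(bbox)

            # Box centre as a (1, 2) array, the point_coords shape SamPredictor.predict expects.
            bbox_center = np.array([[(bbox[0]+bbox[2])/2, (bbox[1]+bbox[3])/2]])
            img_bbox_centers.append(bbox_center)

    return img_paths, img_bbox, img_bbox_centers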
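The printed paths show that the CSV can list the same image more than once (one row per detection), and `predict` computes a new embedding for every row. `set_image` only needs to run once per image, so the rows can be grouped first. This sketch assumes the same `path,x0,y0,x1,y1` layout that `read_csv` assumes; `group_detections` is a hypothetical helper that uses the standard `csv` module so blank lines are skipped and quoted fields are handled.

```python
import csv
import io

import numpy as np


def group_detections(lines):
    """Hypothetical helper: group 'path,x0,y0,x1,y1' rows by image path.

    Returns a dict (in first-seen order) mapping path -> (boxes (K, 4), centers (K, 2)).
    """
    rows_by_path = {}
    for row in csv.reader(lines):
        if not row:
            continue  # csv.reader yields [] for blank lines
        path = row[0]
        coords = [float(v) for v in row[1:5]]
        rows_by_path.setdefault(path, []).append(coords)

    detections = {}
    for path, boxes in rows_by_path.items():
        boxes = np.array(boxes, dtype=np.float32)
        # Box centres, one (x, y) pair per detection.
        centers = np.stack([(boxes[:, 0] + boxes[:, 2]) / 2,
                            (boxes[:, 1] + boxes[:, 3]) / 2], axis=1)
        detections[path] = (boxes, centers)
    return detections


sample = io.StringIO("a.jpg,0,0,10,20\nb.jpg,5,5,15,15\na.jpg,20,20,40,40\n\n")
detections = group_detections(sample)
for path, (boxes, centers) in detections.items():
    print(path, len(boxes), centers.tolist())
# a.jpg 2 [[5.0, 10.0], [30.0, 30.0]]
# b.jpg 1 [[10.0, 10.0]]
```

Each path would then get a single `predictor.set_image(image)` call, followed by one `predictor.predict(...)` per center (as `centers[k:k+1]`), since points passed together in one call are interpreted as prompts for the same object.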

In [None]:
!ls
if __name__ == "__main__":

    img_paths, img_bbox, img_bbox_centers = read_csv(csv_path)

    for i, img_path in enumerate(img_paths):

        bbox_center = img_bbox_centers[i]
        print(img_path)
        #bbox = img_bbox[i]
        predict(img_path, sam_checkpoint, model_type, device, output_path, bbox_center)

lepidoptera  sam_vit_b_01ec64.pth  sam_vit_h_4b8939.pth.1
sample_data  sam_vit_h_4b8939.pth
/content/lepidoptera/segment_anything/Amegilla quadrifasciata/Amegilla quadrifasciata48084.jpg
/content/lepidoptera/segment_anything/Amegilla quadrifasciata/Amegilla quadrifasciata80637.jpeg
/content/lepidoptera/segment_anything/Amegilla quadrifasciata/Amegilla quadrifasciata89995.jpg
/content/lepidoptera/segment_anything/Amegilla quadrifasciata/Amegilla quadrifasciata83743.jpg
/content/lepidoptera/segment_anything/Amegilla quadrifasciata/Amegilla quadrifasciata91602.jpg
/content/lepidoptera/segment_anything/Amegilla quadrifasciata/Amegilla quadrifasciata91602.jpg
/content/lepidoptera/segment_anything/Amegilla quadrifasciata/Amegilla quadrifasciata43569.jpeg
/content/lepidoptera/segment_anything/Amegilla quadrifasciata/Amegilla quadrifasciata46140.jpeg
/content/lepidoptera/segment_anything/Amegilla quadrifasciata/Amegilla quadrifasciata88608.jpeg
/content/lepidoptera/segment_anything/Amegilla qu

In [3]:
# Archive the results written to "/content/lepidoptera/segment_anything/output" so they can be downloaded
!zip -r lepidoptere.zip /content/lepidoptera/segment_anything/output

  adding: content/lepidoptera/segment_anything/output/ (stored 0%)
  adding: content/lepidoptera/segment_anything/output/Amegilla quadrifasciata91602.png (deflated 40%)
  adding: content/lepidoptera/segment_anything/output/Amegilla quadrifasciata48084.png (deflated 46%)
  adding: content/lepidoptera/segment_anything/output/Amegilla quadrifasciata80637.png (deflated 54%)
  adding: content/lepidoptera/segment_anything/output/Amegilla quadrifasciata88628.png (deflated 41%)
  adding: content/lepidoptera/segment_anything/output/Amegilla quadrifasciata88608.png (deflated 49%)
  adding: content/lepidoptera/segment_anything/output/Amegilla quadrifasciata89995.png (deflated 32%)
  adding: content/lepidoptera/segment_anything/output/Amegilla quadrifasciata46140.png (deflated 44%)
  adding: content/lepidoptera/segment_anything/output/Amegilla quadrifasciata43569.png (deflated 53%)
  adding: content/lepidoptera/segment_anything/output/Amegilla quadrifasciata36967.png (deflated 57%)
  adding: conte
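The `!zip` command relies on the `zip` binary being installed and stores the full `content/lepidoptera/...` prefix inside the archive. The standard library's `shutil.make_archive` is a pure-Python alternative. In this self-contained sketch, a temporary directory stands in for `output_path`.

```python
import os
import shutil
import tempfile
import zipfile

with tempfile.TemporaryDirectory() as tmp:
    # Stand-in for output_path containing one saved figure.
    output_dir = os.path.join(tmp, "output")
    os.makedirs(output_dir)
    with open(os.path.join(output_dir, "example.png"), "wb") as f:
        f.write(b"placeholder bytes")

    # Writes <base_name>.zip with paths relative to root_dir.
    archive = shutil.make_archive(os.path.join(tmp, "lepidoptere"), "zip", root_dir=output_dir)
    with zipfile.ZipFile(archive) as zf:
        names = zf.namelist()

print(os.path.basename(archive), names)
```

In the notebook itself, `shutil.make_archive("lepidoptere", "zip", root_dir=output_path)` would write `lepidoptere.zip` to the current directory, with entries named relative to `output_path`.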

In [4]:
!ls

lepidoptera	 sample_data	       sam_vit_h_4b8939.pth
lepidoptere.zip  sam_vit_b_01ec64.pth  sam_vit_h_4b8939.pth.1
