# Notebook adapted from `Roboflow` tutorials
https://github.com/roboflow/notebooks/blob/main/notebooks/how-to-segment-anything-with-sam.ipynb

# Segment Anything Model (SAM)

---

[![GitHub](https://badges.aleen42.com/src/github.svg)](https://github.com/facebookresearch/segment-anything) [![arXiv](https://img.shields.io/badge/arXiv-2304.02643-b31b1b.svg)](https://arxiv.org/abs/2304.02643)

Segment Anything Model (SAM): a new AI model from Meta AI that can "cut out" any object, in any image, with a single click. SAM is a promptable segmentation system with zero-shot generalization to unfamiliar objects and images, without the need for additional training. This notebook is an extension of the [official notebook](https://colab.research.google.com/github/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb) prepared by Meta AI.

![segment anything model](https://media.roboflow.com/notebooks/examples/segment-anything-model-paper.png)


## Pro Tip: Use GPU Acceleration

If you are running this notebook in Google Colab, navigate to `Edit` -> `Notebook settings` -> `Hardware accelerator`, set it to `GPU`, and then click `Save`. This ensures that your notebook runs on a GPU, which significantly speeds up model inference.

## Steps in this Tutorial

In this tutorial, we are going to cover:

- **Before you start** - Make sure you have access to the GPU
- Install Segment Anything Model (SAM)
- Download Example Data
- Load Model
- Automated Mask Generation
- Generate Segmentation with Bounding Box
- Segment Anything in Roboflow Universe Dataset

## Let's begin!

## Before you start

Let's make sure that we have access to a GPU. We can use the `nvidia-smi` command to do that. In case of any problems, navigate to `Edit` -> `Notebook settings` -> `Hardware accelerator`, set it to `GPU`, and then click `Save`.

In [None]:
!nvidia-smi

**NOTE:** To make it easier to manage datasets, images, and models, we create a `HOME` constant (and its `pathlib.Path` counterpart, `home`).

In [None]:
import os
from pathlib import Path

HOME = os.getcwd()
home = Path(HOME)
print("HOME:", HOME)

# 1. Install Segment Anything Model (SAM) and other dependencies

In [None]:
%cd {HOME}

import sys
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

In [None]:
!pip install -q jupyter_bbox_widget roboflow dataclasses-json supervision

## Download SAM weights

In [None]:
%cd {HOME}
!mkdir -p {HOME}/weights
%cd {HOME}/weights

!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
import os

CHECKPOINT_PATH = home / 'weights' / 'sam_vit_h_4b8939.pth'
print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))

## Download Example Data

**NOTE:** Let's download a few example images. Feel free to use your own images or videos.

In [None]:
%cd {HOME}
!git clone https://github.com/kurmukovai/sam_tutorial.git {HOME}/sam_example

## Load Model

In [None]:
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = 'vit_h'
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
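`MODEL_TYPE` has to match the checkpoint file: loading, for example, a ViT-H checkpoint into a ViT-B model will typically fail with a state-dict size mismatch. A minimal sketch, assuming the three official checkpoint file names published in the SAM repository, that infers the registry key from the file name. `model_type_for_checkpoint` is a hypothetical helper, not part of `segment_anything`.

```python
from pathlib import Path

# Official SAM checkpoint file names, keyed by the sam_model_registry key they belong to.
SAM_CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}


def model_type_for_checkpoint(checkpoint_path):
    """Return the registry key whose file name matches `checkpoint_path`.

    Hypothetical helper: it only looks at the file name, so a renamed
    checkpoint will not be recognised.
    """
    name = Path(checkpoint_path).name
    for model_type, file_name in SAM_CHECKPOINTS.items():
        if name == file_name:
            return model_type
    raise ValueError(f"Unknown SAM checkpoint: {name}")


print(model_type_for_checkpoint("weights/sam_vit_h_4b8939.pth"))  # vit_h
```

In the notebook, `MODEL_TYPE = model_type_for_checkpoint(CHECKPOINT_PATH)` would keep the two values in sync if you switch to a smaller checkpoint.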

# 2. Automated Mask Generation

To run automatic mask generation, pass the loaded SAM model to the `SamAutomaticMaskGenerator` class. The checkpoint path was already set above. Running on CUDA with the default (`vit_h`) model is recommended.
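The automatic generator needs no clicks from you: it prompts SAM with a regular grid of points (`points_per_side`, 32 by default, i.e. 32 × 32 = 1024 prompts), then filters the resulting masks by quality and removes duplicates. A pure-Python sketch of such a grid, assuming evenly spaced cell-centre points in normalised `[0, 1]` coordinates; it illustrates the idea rather than reproducing the library code.

```python
def build_point_grid(points_per_side):
    """Normalised (x, y) prompt points at the centres of an n x n grid.

    Each coordinate lies strictly inside (0, 1); x varies fastest.
    """
    step = 1.0 / points_per_side
    coords = [(i + 0.5) * step for i in range(points_per_side)]
    return [(x, y) for y in coords for x in coords]


def to_pixels(points, width, height):
    """Scale normalised points to pixel coordinates for an image of width x height."""
    return [(x * width, y * height) for x, y in points]


grid = build_point_grid(32)  # 32 matches the default points_per_side
print(len(grid))             # 1024
print(to_pixels(build_point_grid(2), 100, 50))
```

A denser grid (larger `points_per_side`) can find smaller objects at the cost of more prompts to decode.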

## Generate masks with SAM

In [None]:
import cv2
import supervision as sv

# IMAGE_NAME = 'berries.png'
IMAGE_NAME = 'xray.jpeg'
IMAGE_PATH = home / 'sam_example' / 'images' / IMAGE_NAME

image_bgr = cv2.imread(str(IMAGE_PATH))
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

In [None]:
mask_generator = SamAutomaticMaskGenerator(sam)
sam_result = mask_generator.generate(image_rgb)

### Output format

`SamAutomaticMaskGenerator` returns a `list` of masks, where each mask is a `dict` containing the following information about the mask:

* `segmentation` - `[np.ndarray]` - the mask with `(H, W)` shape and `bool` type
* `area` - `[int]` - the area of the mask in pixels
* `bbox` - `[List[int]]` - the bounding box of the mask in `xywh` format
* `predicted_iou` - `[float]` - the model's own prediction of the mask's quality
* `point_coords` - `[List[List[float]]]` - the sampled input point that generated this mask
* `stability_score` - `[float]` - an additional measure of mask quality
* `crop_box` - `[List[int]]` - the crop of the image used to generate this mask in `xywh` format
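Because every record carries `predicted_iou`, `stability_score`, and `area`, you can post-filter the output with plain Python. A minimal sketch, assuming the record layout listed above; `filter_masks` is a hypothetical helper and its threshold values are illustrative, not recommended defaults. (The generator already applies its own `pred_iou_thresh` and `stability_score_thresh`, 0.88 and 0.95 by default; this just tightens them.)

```python
def filter_masks(records, min_predicted_iou=0.9, min_stability=0.95, min_area=0):
    """Keep SAM mask records that pass simple quality thresholds, largest first."""
    kept = [
        r for r in records
        if r["predicted_iou"] >= min_predicted_iou
        and r["stability_score"] >= min_stability
        and r["area"] >= min_area
    ]
    return sorted(kept, key=lambda r: r["area"], reverse=True)


# Hypothetical records containing only the fields the filter needs.
records = [
    {"area": 120, "predicted_iou": 0.97, "stability_score": 0.98},
    {"area": 900, "predicted_iou": 0.85, "stability_score": 0.99},
    {"area": 450, "predicted_iou": 0.93, "stability_score": 0.96},
]
print([r["area"] for r in filter_masks(records)])  # [450, 120]
```

On real output this would be called as `filter_masks(sam_result, min_area=100)`.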

## Results visualisation with Supervision

As of version `0.5.0` Supervision has native support for SAM.
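Supervision's `sv.Detections` stores boxes as `xyxy`, while SAM reports each `bbox` as `xywh`, so the conversion has to happen somewhere. A pure-Python sketch of that step, assuming the record layout listed above; `sam_result_to_boxes` is a hypothetical helper that sorts by area like the `masks` list in the next cell, and Supervision's own `from_sam` implementation may differ in details.

```python
def xywh_to_xyxy(box):
    """Convert an [x, y, width, height] box to [x_min, y_min, x_max, y_max]."""
    x, y, w, h = box
    return [x, y, x + w, y + h]


def sam_result_to_boxes(sam_result):
    """Sort SAM records by area (largest first) and return their boxes as xyxy."""
    ordered = sorted(sam_result, key=lambda r: r["area"], reverse=True)
    return [xywh_to_xyxy(r["bbox"]) for r in ordered]


example = [
    {"area": 10, "bbox": [5, 5, 2, 5]},
    {"area": 40, "bbox": [0, 0, 8, 5]},
]
print(sam_result_to_boxes(example))  # [[0, 0, 8, 5], [5, 5, 7, 10]]
```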

In [None]:
masks = [
    mask['segmentation']
    for mask
    in sorted(sam_result, key=lambda x: x['area'], reverse=True)
]

print(sam_result[0].keys())
print(len(masks))

In [None]:
mask_annotator = sv.MaskAnnotator(color_map="index")

detections = sv.Detections.from_sam(sam_result=sam_result)

annotated_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections)

sv.plot_images_grid(
    images=[image_bgr, annotated_image],
    grid_size=(1, 2),
    titles=['source image', 'segmented image']
)

## Save masks to RGBA

In [None]:
from PIL import Image
import numpy as np

In [None]:
def save_mask(mask, file):
    """Save a boolean mask as an RGBA PNG: black pixels, opaque where the mask is True."""
    rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)  # RGB channels stay black
    rgba[..., 3] = np.where(mask, 255, 0)               # alpha channel encodes the mask
    Image.fromarray(rgba).save(file)                    # (H, W, 4) uint8 -> RGBA image

In [None]:
for i, mask in enumerate(masks):
    save_mask(mask, f'mask_{i}.png')
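Saving one PNG per mask produces many files. An alternative is a single label image in which each pixel stores the index of the mask covering it. A minimal NumPy sketch, assuming `masks` is the area-sorted list of boolean arrays built earlier; `masks_to_label_map` is a hypothetical helper, and where masks overlap, the later (smaller) mask wins.

```python
import numpy as np


def masks_to_label_map(masks):
    """Merge boolean masks (each H x W) into one int32 label image.

    Pixels get the 1-based index of the last mask covering them; 0 means
    background. With masks sorted largest-first, small objects overwrite
    the large regions they sit inside.
    """
    if len(masks) == 0:
        raise ValueError("need at least one mask")
    label_map = np.zeros(masks[0].shape, dtype=np.int32)
    for index, mask in enumerate(masks, start=1):
        label_map[mask] = index
    return label_map


big = np.zeros((4, 4), dtype=bool)
big[:, :3] = True
small = np.zeros((4, 4), dtype=bool)
small[1:3, 1:3] = True
print(masks_to_label_map([big, small]))
```

In the notebook, `np.save('labels.npy', masks_to_label_map(masks))` stores all masks in one file.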

# 3. Generate Segmentation with Bounding Box

The `SamPredictor` class provides a simple interface for prompting the model. You first set an image with the `set_image` method, which computes the image embedding once. Prompts are then passed to the `predict` method, which efficiently predicts masks from them. The model accepts point and box prompts as input, as well as a mask from a previous prediction.
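Point prompts are passed as two arrays: `point_coords` with shape `(N, 2)` in `(x, y)` pixel coordinates, and `point_labels` with shape `(N,)`, where `1` marks a foreground point and `0` a background point. A minimal sketch that assembles them, plus an optional box, into keyword arguments for `predict`; `make_prompts` is a hypothetical helper and the coordinates are made up.

```python
import numpy as np


def make_prompts(foreground=(), background=(), box=None):
    """Build keyword arguments for SamPredictor.predict from plain Python values.

    Points are (x, y) pixel coordinates; label 1 marks foreground points and
    label 0 background points. `box` is [x_min, y_min, x_max, y_max].
    """
    kwargs = {}
    points = list(foreground) + list(background)
    if points:
        kwargs["point_coords"] = np.array(points, dtype=np.float32)  # shape (N, 2)
        kwargs["point_labels"] = np.array(
            [1] * len(foreground) + [0] * len(background), dtype=np.int32
        )  # shape (N,)
    if box is not None:
        kwargs["box"] = np.array(box, dtype=np.float32)  # shape (4,)
    return kwargs


prompts = make_prompts(foreground=[(450, 300)], background=[(100, 80)], box=[50, 40, 600, 500])
print({name: value.shape for name, value in prompts.items()})
# In the notebook: masks, scores, logits = mask_predictor.predict(**prompts, multimask_output=True)
```

To refine a result, the official SAM examples also feed the low-resolution logits of the best previous mask back in as `mask_input`, e.g. `logits[np.argmax(scores)][None, :, :]`.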

In [None]:
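# helper function that encodes an image file as a data URL for the widget

import base64
import mimetypes

def encode_image(filepath):
    with open(filepath, 'rb') as f:
        image_bytes = f.read()
    # guess the MIME type from the extension (e.g. image/png, image/jpeg)
    mime_type = mimetypes.guess_type(str(filepath))[0] or 'image/jpeg'
    encoded = base64.b64encode(image_bytes).decode('utf-8')
    return f"data:{mime_type};base64,{encoded}"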

## Draw Box

**NOTE:** Execute the cells below and use your mouse to draw a bounding box on the image 👇


In [None]:
IMAGE_NAME = 'berries.png'
IMAGE_PATH = home / 'sam_example' / 'images' / IMAGE_NAME

In [None]:
import sys

IS_COLAB = 'google.colab' in sys.modules  # custom widgets need extra setup in Colab

if IS_COLAB:
    from google.colab import output
    output.enable_custom_widget_manager()

from jupyter_bbox_widget import BBoxWidget

widget = BBoxWidget()
widget.image = encode_image(IMAGE_PATH)
widget

In [None]:
widget.bboxes

## Generate masks with SAM

**NOTE:** The `SamPredictor.predict` method takes a `box` argument as an `np.ndarray` in `[x_min, y_min, x_max, y_max]` format, while the widget returns each box as `x`, `y`, `width`, and `height`. Let's convert the data first.

In [None]:
image_bgr = cv2.imread(str(IMAGE_PATH))
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)


mask_predictor = SamPredictor(sam)
mask_predictor.set_image(image_rgb)

In [None]:
import numpy as np

box = widget.bboxes[0]
box = np.array([
    box['x'], 
    box['y'], 
    box['x'] + box['width'], 
    box['y'] + box['height']
])

masks, scores, logits = mask_predictor.predict(
    box=box,
    multimask_output=True
)
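The cell above converts only the first entry of `widget.bboxes`. If you draw several boxes, the same conversion applies to each one. A minimal pure-Python sketch, assuming each entry is a dict with `x`, `y`, `width`, and `height` keys as used above, that also clips boxes to the image bounds; `widget_boxes_to_xyxy` is a hypothetical helper.

```python
def widget_boxes_to_xyxy(bboxes, width=None, height=None):
    """Convert widget boxes (x, y, width, height) to [x_min, y_min, x_max, y_max].

    If the image size is given, coordinates are clipped to the image so that
    a box dragged past the border stays valid.
    """
    result = []
    for b in bboxes:
        x_min, y_min = b["x"], b["y"]
        x_max, y_max = x_min + b["width"], y_min + b["height"]
        if width is not None and height is not None:
            x_min, x_max = max(0, x_min), min(width, x_max)
            y_min, y_max = max(0, y_min), min(height, y_max)
        result.append([x_min, y_min, x_max, y_max])
    return result


drawn = [
    {"x": 10, "y": 20, "width": 30, "height": 40},
    {"x": -5, "y": 90, "width": 50, "height": 30},
]
print(widget_boxes_to_xyxy(drawn, width=100, height=100))
```

Each resulting box can be wrapped in `np.array` and passed to `mask_predictor.predict` one at a time; the image size is available as `image_rgb.shape[1]` (width) and `image_rgb.shape[0]` (height).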

## Results visualisation with Supervision

In [None]:
box_annotator = sv.BoxAnnotator(color=sv.Color.red())
mask_annotator = sv.MaskAnnotator(color=sv.Color.red(), color_map="index")

detections = sv.Detections(
    xyxy=sv.mask_to_xyxy(masks=masks),
    mask=masks
)

print(len(detections))
print(detections.area)

# multimask_output=True returned several candidate masks; keep only the smallest one
detections = detections[detections.area == np.min(detections.area)]

source_image = box_annotator.annotate(scene=image_bgr.copy(), detections=detections, skip_label=True)
segmented_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections)

sv.plot_images_grid(
    images=[source_image, segmented_image],
    grid_size=(1, 2),
    titles=['source image', 'segmented image']
)
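With `multimask_output=True`, `predict` returns several candidate masks together with a quality score for each. The cell above keeps the smallest mask; picking the candidate with the highest score is another natural choice. A minimal sketch of both policies, assuming `scores` and `areas` are equal-length sequences; `pick_mask_index` is a hypothetical helper and the values below are made up.

```python
def pick_mask_index(scores, areas, policy="score"):
    """Choose one of the candidate masks returned with multimask_output=True.

    policy="score":    index of the highest predicted quality score.
    policy="smallest": index of the smallest mask area.
    """
    if len(scores) == 0 or len(scores) != len(areas):
        raise ValueError("scores and areas must be non-empty and the same length")
    indices = range(len(scores))
    if policy == "score":
        return max(indices, key=lambda i: scores[i])
    if policy == "smallest":
        return min(indices, key=lambda i: areas[i])
    raise ValueError(f"unknown policy: {policy}")


scores = [0.81, 0.95, 0.90]  # made-up values for illustration
areas = [1200, 5400, 300]
print(pick_mask_index(scores, areas), pick_mask_index(scores, areas, "smallest"))  # 1 2
```

With the notebook's NumPy outputs, `areas = masks.sum(axis=(1, 2))` gives the pixel count of each candidate, and `masks[pick_mask_index(scores, areas)]` selects one.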

## Interaction with segmentation results

In [None]:
import supervision as sv

sv.plot_images_grid(
    images=masks,
    grid_size=(1, 4),
    size=(16, 4)
)

# 4. Run multiple prompts on a single image


**NOTE:** A `SamPredictor` object has two methods:
 - `set_image()` is an expensive and relatively slow feature extractor (ViT-based image encoder)
 - `predict()` is a lightweight and fast mask decoder

You only need to run `set_image()` once per image; after that, you can run `predict()` with as many different prompts as you like.
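Because the embedding computed by `set_image()` can be reused, a loop over many prompts should call it once per image. A minimal sketch of that pattern, assuming only the `set_image`/`predict` interface described above; `EmbeddingCache` and `_StubPredictor` are hypothetical and exist only to make the example runnable without a model.

```python
class EmbeddingCache:
    """Call `set_image` only when the image changes, then reuse the embedding.

    Wraps any predictor exposing set_image(image) and predict(**prompts).
    Images are compared by object identity, so pass the same array object
    to reuse the cached embedding.
    """

    def __init__(self, predictor):
        self.predictor = predictor
        self._current_image = None
        self.set_image_calls = 0

    def predict(self, image, **prompts):
        if image is not self._current_image:
            self.predictor.set_image(image)  # expensive image-encoder pass
            self._current_image = image
            self.set_image_calls += 1
        return self.predictor.predict(**prompts)  # cheap prompt-decoder pass


class _StubPredictor:
    """Stand-in for SamPredictor so the sketch runs without a model."""

    def set_image(self, image):
        self.image = image

    def predict(self, **prompts):
        return sorted(prompts)


stub = _StubPredictor()
cached = EmbeddingCache(stub)
image = [[0, 0], [0, 0]]
for box in ([0, 0, 1, 1], [0, 0, 2, 2], [1, 1, 2, 2]):
    cached.predict(image, box=box)
print(cached.set_image_calls)  # 1
```

In the notebook you would wrap `mask_predictor` instead of the stub and call `cached.predict(image_rgb, box=box, multimask_output=False)`.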

In [None]:
%%timeit

mask_predictor.set_image(image_rgb)

In [None]:
%%timeit

mask, _, _ = mask_predictor.predict(
    box=box,
    multimask_output=False
)