## Setup

Let's make sure that we have access to a GPU by running the `nvidia-smi` command. If the command fails or lists no GPU, navigate to `Edit` -> `Notebook settings` -> `Hardware accelerator`, set it to `GPU`, and then click `Save`.

In [None]:
!nvidia-smi
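
The same check can be done from Python, which is handy if later cells should fail early when no GPU is present. The snippet below is a minimal sketch using only the standard library; `gpu_visible` is a hypothetical helper name, and the check only confirms that `nvidia-smi -L` runs and lists at least one GPU.

```python
import shutil
import subprocess


def gpu_visible() -> bool:
    """Return True if `nvidia-smi` is on PATH and lists at least one GPU."""
    executable = shutil.which("nvidia-smi")
    if executable is None:
        return False
    try:
        # `-L` prints one line per detected GPU.
        result = subprocess.run(
            [executable, "-L"],
            capture_output=True,
            text=True,
            timeout=30,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0 and "GPU" in result.stdout


print("GPU visible:", gpu_visible())
```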

**NOTE:** To make it easier to manage datasets, images, and models, we create a `HOME` constant that holds the notebook's working directory.

In [None]:
import os
HOME = os.getcwd()
print("HOME:", HOME)

## Install Grounding DINO and Segment Anything Model

Our project combines two models: [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) for zero-shot detection and [Segment Anything Model (SAM)](https://github.com/facebookresearch/segment-anything) for converting boxes into segmentation masks. We have to install both first.


In [None]:
%cd {HOME}
!git clone https://github.com/IDEA-Research/GroundingDINO.git
%cd {HOME}/GroundingDINO
!git checkout -q 57535c5a79791cb76e36fdb64975271354f10251
!pip install -q -e .

In [None]:
%cd {HOME}

import sys
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

**NOTE:** To glue all the elements of our demo together we will use the [`supervision`](https://github.com/roboflow/supervision) pip package, which will help us **process, filter, and visualize our detections, and save our dataset**. An older version of `supervision` was installed with Grounding DINO. However, this demo needs functionality introduced in later releases. Therefore, we uninstall the current `supervision` version, install the latest release, and print its version number.



In [None]:
!pip uninstall -y supervision
!pip install -q supervision

import supervision as sv
print(sv.__version__)
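
**NOTE:** `sv.__version__` reports the module that is currently imported, which can be stale if `supervision` was imported before the reinstall and the runtime was not restarted. As a minimal sketch, the standard-library `importlib.metadata` module reads the version that is installed on disk instead; `installed_version` is a hypothetical helper name.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


print("supervision:", installed_version("supervision") or "not installed")
```

If the two version numbers disagree, restarting the runtime and re-running the import cell should bring them back in line.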

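**NOTE:** At the end of the tutorial we will upload our annotations to [Roboflow](https://roboflow.com). To automate this process with the API, let's install the `roboflow` pip package. The same cell also upgrades `gradio`, which powers the interactive interface built at the end of this notebook.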

In [None]:
!pip install -q roboflow
!pip install --upgrade gradio

### Download Grounding DINO Model Weights

To run Grounding DINO we need two files: a configuration file and the model weights. The configuration file is part of the [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) repository, which we have already cloned. The weights file, on the other hand, has to be downloaded. We store the paths to both files in the `GROUNDING_DINO_CONFIG_PATH` and `GROUNDING_DINO_CHECKPOINT_PATH` variables and verify that both files exist on disk.

In [None]:
import os

GROUNDING_DINO_CONFIG_PATH = os.path.join(HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
print(GROUNDING_DINO_CONFIG_PATH, "; exist:", os.path.isfile(GROUNDING_DINO_CONFIG_PATH))

In [None]:
%cd {HOME}
!mkdir -p {HOME}/weights
%cd {HOME}/weights

!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
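
Re-running the cell above downloads the weights again and, because the file already exists, `wget` saves the new copy under a different name. A minimal sketch of a download step that skips files already on disk, using only the standard library, could look like this; `download_if_missing` is a hypothetical helper, not part of either model's tooling.

```python
import os
import urllib.request


def download_if_missing(url: str, directory: str) -> str:
    """Download `url` into `directory` unless a non-empty copy already exists."""
    os.makedirs(directory, exist_ok=True)
    destination = os.path.join(directory, os.path.basename(url))
    if not os.path.isfile(destination) or os.path.getsize(destination) == 0:
        urllib.request.urlretrieve(url, destination)
    return destination


# Example call (commented out so the sketch runs without network access):
# download_if_missing(
#     "https://github.com/IDEA-Research/GroundingDINO/releases/download/"
#     "v0.1.0-alpha/groundingdino_swint_ogc.pth",
#     os.path.join(os.getcwd(), "weights"),
# )
```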

In [None]:
import os

GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(HOME, "weights", "groundingdino_swint_ogc.pth")
print(GROUNDING_DINO_CHECKPOINT_PATH, "; exist:", os.path.isfile(GROUNDING_DINO_CHECKPOINT_PATH))

### Download Segment Anything Model (SAM) Weights

As with Grounding DINO, in order to run SAM we need a weights file, which we must first download. We store the path to the local weights file in the `SAM_CHECKPOINT_PATH` variable and verify that the file exists on disk.

In [None]:
%cd {HOME}
!mkdir -p {HOME}/weights
%cd {HOME}/weights

!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
import os

SAM_CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(SAM_CHECKPOINT_PATH, "; exist:", os.path.isfile(SAM_CHECKPOINT_PATH))
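
The `print` checks above only report whether each file exists. Because `wget -q` suppresses error output, a failed download is easy to miss until a model fails to load. As a sketch, the hypothetical helper below collects every path that is missing or empty so the notebook can stop early.

```python
import os
from typing import Iterable, List


def missing_files(paths: Iterable[str]) -> List[str]:
    """Return the paths that do not exist as files or are empty."""
    return [
        path
        for path in paths
        if not os.path.isfile(path) or os.path.getsize(path) == 0
    ]


# Example use with the variables defined in the cells above (commented out so
# the sketch stays self-contained):
# missing = missing_files([
#     GROUNDING_DINO_CONFIG_PATH,
#     GROUNDING_DINO_CHECKPOINT_PATH,
#     SAM_CHECKPOINT_PATH,
# ])
# if missing:
#     raise FileNotFoundError(f"Missing or empty files: {missing}")
```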

## Load models

In [None]:
import torch

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Load Grounding DINO Model

In [None]:
%cd {HOME}/GroundingDINO

from groundingdino.util.inference import Model

grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)

### Load Segment Anything Model (SAM)

In [None]:
SAM_ENCODER_VERSION = "vit_h"
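
**NOTE:** The encoder version has to match the checkpoint downloaded earlier (`sam_vit_h_4b8939.pth` for `vit_h`). The sketch below is one way to check that before loading the model. The file names for the other two encoders are taken from the segment-anything README and are an assumption here, since this notebook only downloads the `vit_h` weights; `checkpoint_matches` is a hypothetical helper.

```python
import os

# Checkpoint file names per encoder, as listed in the segment-anything README
# (assumed; only the vit_h file is downloaded in this notebook).
SAM_CHECKPOINT_FILES = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}


def checkpoint_matches(encoder_version: str, checkpoint_path: str) -> bool:
    """Return True if the checkpoint file name belongs to the given encoder."""
    expected = SAM_CHECKPOINT_FILES.get(encoder_version)
    return expected is not None and os.path.basename(checkpoint_path) == expected


print(checkpoint_matches("vit_h", "weights/sam_vit_h_4b8939.pth"))
```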

In [None]:
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH).to(device=DEVICE)
sam_predictor = SamPredictor(sam)

In [None]:
f"{HOME}/data"

In [None]:
%cd {HOME}
!mkdir -p {HOME}/data
%cd {HOME}/data

!wget -q https://media.roboflow.com/notebooks/examples/dog.jpeg
!wget -q https://media.roboflow.com/notebooks/examples/dog-2.jpeg
!wget -q https://media.roboflow.com/notebooks/examples/dog-3.jpeg
!wget -q https://media.roboflow.com/notebooks/examples/dog-4.jpeg
!wget -q https://media.roboflow.com/notebooks/examples/dog-5.jpeg
!wget -q https://media.roboflow.com/notebooks/examples/dog-6.jpeg
!wget -q https://media.roboflow.com/notebooks/examples/dog-7.jpeg
!wget -q https://media.roboflow.com/notebooks/examples/dog-8.jpeg
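
The eight example images follow a simple naming pattern, so the list of URLs can also be generated instead of written out by hand. A minimal sketch, where `BASE_URL`, `file_names` and `image_urls` are names introduced only for this illustration:

```python
BASE_URL = "https://media.roboflow.com/notebooks/examples"

# The first image has no numeric suffix; the rest run from dog-2 to dog-8.
file_names = ["dog.jpeg"] + [f"dog-{index}.jpeg" for index in range(2, 9)]
image_urls = [f"{BASE_URL}/{name}" for name in file_names]

for url in image_urls:
    print(url)
```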

## Explainable User Interface

In [None]:
import numpy as np
import cv2
from typing import List
from segment_anything import SamPredictor
import matplotlib.pyplot as plt

def enhance_class_name(class_names: List[str]) -> List[str]:
    return [
        f"all {class_name}s"
        for class_name
        in class_names
    ]

def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        masks, scores, logits = sam_predictor.predict(
            box=box,
            multimask_output=True
        )
        index = np.argmax(scores)
        result_masks.append(masks[index])
    return np.array(result_masks)

def upscale_nearest_neighbor(image, scale_factor):
    """Upscale an image using nearest-neighbor interpolation."""

    height, width = image.shape[:2]
    new_height = int(height * scale_factor)
    new_width = int(width * scale_factor)

    # Map every output row and column back to its source index, then gather
    # all pixels with one NumPy indexing operation instead of a Python loop.
    rows = (np.arange(new_height) / scale_factor).astype(int)
    cols = (np.arange(new_width) / scale_factor).astype(int)

    return image[rows[:, None], cols]

def annotate_image(image, classes, detections):
    height, width = image.shape[:2]
    scale_factor = min(width, height)
    text_scale = 0.001 * scale_factor

    box_annotator = sv.BoxAnnotator()
    mask_annotator = sv.MaskAnnotator()
    label_annotator = sv.LabelAnnotator(text_scale=text_scale, text_thickness=3)

    labels = [
        f"{classes[class_id]} {confidence:0.2f}"
        for _, _, confidence, class_id, _, _
        in detections
    ]
    annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
    annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)

    %matplotlib inline
    sv.plot_image(annotated_image, (16, 16))

    return annotated_image

def segment_image(image, classes, box_threshold, text_threshold):
    detections = grounding_dino_model.predict_with_classes(
      image=image,
      classes=enhance_class_name(class_names=classes),
      box_threshold=box_threshold,
      text_threshold=text_threshold
    )

    detections = detections[detections.class_id != None]
    detections.mask = segment(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy
    )

    return detections

##

def segment_original_image(img_input, query, box_threshold, text_threshold):
    image = np.array(img_input)
    if image.shape[2] == 4:
        image = image[:, :, :3]

    # Accept both comma- and space-separated class names (e.g. "car, dog, ear").
    CLASSES = query.replace(",", " ").split()
    BOX_THRESHOLD = box_threshold
    TEXT_THRESHOLD = text_threshold

    detections = segment_image(image, CLASSES, BOX_THRESHOLD, TEXT_THRESHOLD)

    annotated_image = annotate_image(image, CLASSES, detections)
    return annotated_image

##

def segment_selected_image(img_input, query, box_threshold, text_threshold, selected_pos=None):
    image = np.array(img_input)
    if image.shape[2] == 4:
        image = image[:, :, :3]

    # Accept both comma- and space-separated class names (e.g. "car, dog, ear").
    CLASSES = query.replace(",", " ").split()
    BOX_THRESHOLD = box_threshold
    TEXT_THRESHOLD = text_threshold

    detections = segment_image(image, CLASSES, BOX_THRESHOLD, TEXT_THRESHOLD)

    selected_detections = []
    if selected_pos is not None:
      for xyxy, _, _, _, _, _ in detections:
        if int(float(xyxy[0].astype(float))) <= selected_pos[0] <= int(float(xyxy[2].astype(float))) and int(float(xyxy[1].astype(float))) <= selected_pos[1] <= int(float(xyxy[3].astype(float))):
          selected_detections.append(xyxy)

    selected_xyxy = None
    selected_size = 0
    if len(selected_detections) != 0:
      for xyxy in selected_detections:
        if selected_xyxy is None:
          selected_xyxy = xyxy
        else:
          curr_size = (xyxy[2] - xyxy[0]) * (xyxy[3] - xyxy[1])
          min_size = (selected_xyxy[2] - selected_xyxy[0]) * (selected_xyxy[3] - selected_xyxy[1])
          if curr_size < min_size:
            selected_xyxy = xyxy
            selected_size = curr_size

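    if selected_xyxy is not None:
      # Crop the selected box with a 25-pixel margin, clamped to the image
      # bounds so a box near the edge never produces a negative slice index.
      x1, y1, x2, y2 = (int(value) for value in selected_xyxy)
      height, width = image.shape[:2]
      image = image[max(y1 - 25, 0):min(y2 + 25, height),
              max(x1 - 25, 0):min(x2 + 25, width), :3]
      image = upscale_nearest_neighbor(image, 5)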
      detections = segment_image(image, CLASSES, BOX_THRESHOLD, TEXT_THRESHOLD)

      annotated_image = annotate_image(image, CLASSES, detections)
      return image, annotated_image

    annotated_image = annotate_image(image, CLASSES, detections)
    return image, annotated_image

##

import gradio as gr

css = """
/* CSS to make the image scale dynamically */
.gr-box .gr-image {
    width: 100%;
    height: auto;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Segment Anything Model")

    with gr.Column():
        with gr.Row(equal_height=True):
            box_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.35, step=0.01, label="Box Threshold")
            text_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Text Threshold")
        textbox = gr.Textbox(lines=1, show_label=False, placeholder="What objects would you like to segment? (e.g. car, dog, ear)")

    with gr.Row(equal_height=True):
        image_input = gr.Image(type="pil", label="Upload an Image")
        image_output = gr.Image(type="pil", label="Segmented Image")

    def on_input_change(image, query, box_threshold, text_threshold):
        if image is not None:
            return segment_original_image(image, query, box_threshold, text_threshold)
        return None

    def on_img_select(image, query, box_threshold, text_threshold, evt: gr.SelectData):
        selected_pos = evt.index
        return segment_selected_image(image, query, box_threshold, text_threshold, selected_pos)

    image_input.change(on_input_change, inputs=[image_input, textbox, box_threshold, text_threshold], outputs=image_output)
    textbox.change(on_input_change, inputs=[image_input, textbox, box_threshold, text_threshold], outputs=image_output)
    box_threshold.change(on_input_change, inputs=[image_input, textbox, box_threshold, text_threshold], outputs=image_output)
    text_threshold.change(on_input_change, inputs=[image_input, textbox, box_threshold, text_threshold], outputs=image_output)

    image_output.select(on_img_select, inputs=[image_input, textbox, box_threshold, text_threshold], outputs=[image_input, image_output])

demo.launch(debug=True)
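
When the segmented image is clicked, `segment_selected_image` keeps the detection boxes that contain the clicked point and zooms into the smallest one. The selection rule on its own can be sketched in plain Python as follows. `Box`, `box_area` and `smallest_box_containing` are hypothetical names, and unlike the notebook code this sketch compares the raw coordinates without truncating them to integers first.

```python
from typing import Optional, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x_min, y_min, x_max, y_max)


def box_area(box: Box) -> float:
    """Area of an axis-aligned box."""
    return (box[2] - box[0]) * (box[3] - box[1])


def smallest_box_containing(
    point: Tuple[float, float], boxes: Sequence[Box]
) -> Optional[Box]:
    """Return the smallest box that contains `point`, or None if none does."""
    x, y = point
    hits = [
        box for box in boxes
        if box[0] <= x <= box[2] and box[1] <= y <= box[3]
    ]
    return min(hits, key=box_area, default=None)


boxes = [(0, 0, 100, 100), (20, 20, 60, 60), (200, 200, 300, 300)]
print(smallest_box_containing((30, 30), boxes))    # (20, 20, 60, 60)
print(smallest_box_containing((150, 150), boxes))  # None
```

Preferring the smallest box means that a click on a nested object, such as an ear inside a dog's box, zooms into the ear and not the whole dog.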