<a href="https://colab.research.google.com/github/jyotidabass/GroundingSAM-Gradio-App/blob/main/GroundingSAM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
#!pip install spaces
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import torch
from transformers import SamModel, SamProcessor
import spaces
import numpy as np
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# SAM (Segment Anything) produces masks from box prompts. Place it on `device`
# like the detector below instead of a hard-coded "cuda", which fails on
# CPU-only machines.
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Grounding DINO turns free-text labels into bounding boxes.
model_id = "IDEA-Research/grounding-dino-base"

dino_processor = AutoProcessor.from_pretrained(model_id)
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

def infer_dino(img, text_queries, score_threshold):
  """Run Grounding DINO zero-shot detection for the given labels on one image.

  Args:
    img: Image array; ``img.shape[:2]`` is read as ``(height, width)``
      (``gr.Image`` supplies a numpy array in that layout).
    text_queries: Iterable of candidate label strings. Surrounding whitespace
      is stripped, blank entries are dropped and labels are lowercased,
      because Grounding DINO expects lowercase phrases each ending in ".".
    score_threshold: Minimum box confidence, forwarded to the post-processor
      as ``box_threshold``.

  Returns:
    The list produced by ``post_process_grounded_object_detection``: one dict
    per image with ``"scores"``, ``"labels"`` and ``"boxes"`` (boxes scaled to
    the image size). Returns ``[]`` when no non-blank label was supplied.
  """
  # Build the prompt in Grounding DINO's format: "label one. label two."
  cleaned = [query.strip().lower() for query in text_queries]
  queries = " ".join(f"{query}." for query in cleaned if query)
  if not queries:
    # Nothing to ground; running the model on an empty prompt is meaningless.
    return []

  # numpy images are (height, width, channels); the post-processor expects
  # target sizes as (height, width).
  height, width = img.shape[:2]
  target_sizes = [(height, width)]

  inputs = dino_processor(text=queries, images=img, return_tensors="pt").to(device)

  with torch.no_grad():
    outputs = dino_model(**inputs)
    # Move predictions to the CPU so the post-processed tensors are CPU tensors.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = dino_processor.post_process_grounded_object_detection(
        outputs=outputs,
        input_ids=inputs.input_ids,
        box_threshold=score_threshold,
        target_sizes=target_sizes,
    )
  return results


@spaces.GPU
def query_image(img, text_queries, dino_threshold):
  """Detect comma-separated labels in ``img`` and segment each detection with SAM.

  Args:
    img: Input image array from ``gr.Image``; ``None`` when no image was given.
    text_queries: Comma-separated candidate labels, e.g. ``"cat, fishnet"``.
    dino_threshold: Box-confidence threshold forwarded to Grounding DINO.

  Returns:
    ``(img, annotations)`` for ``gr.AnnotatedImage``, where ``annotations`` is a
    list of ``(mask, label)`` pairs; each mask carries a leading axis of size 1.
    Returns ``None`` when ``img`` is ``None``.
  """
  if img is None:
    return None

  requested = [query.strip() for query in text_queries.split(",") if query.strip()]
  if not requested:
    return img, []

  dino_output = infer_dino(img, requested, dino_threshold)

  result_labels = []
  for pred in dino_output:
    # Keep only detections whose phrase decoded to a non-empty label.
    kept = [
        (box, label)
        for box, label in zip(pred["boxes"].cpu(), pred["labels"])
        if label != ""
    ]
    if not kept:
      continue

    # Prompt SAM with every box at once (layout: [image][box][x0, y0, x1, y1]),
    # so the image encoder runs once instead of once per detection.
    inputs = sam_processor(
        img,
        input_boxes=[[box.tolist() for box, _ in kept]],
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
      outputs = sam_model(**inputs)

    # One tensor per image, indexed [box][candidate mask], resized to the
    # original image size.
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0]

    for index, (_, label) in enumerate(kept):
      # Use SAM's first candidate mask for this box.
      mask = masks[index][0].numpy()
      result_labels.append((mask[np.newaxis, ...], label))
  return img, result_labels

import gradio as gr

# Text shown above the demo; it must name the detector actually used
# (GroundingDINO), not OWLv2.
description = "This Space combines [GroundingDINO](https://huggingface.co/IDEA-Research/grounding-dino-base), a bleeding-edge zero-shot object detection model, with [SAM](https://huggingface.co/facebook/sam-vit-base), the state-of-the-art mask generation model. SAM normally doesn't accept text input. Combining SAM with GroundingDINO makes SAM text promptable. Try an example or input an image and comma-separated candidate labels to segment."

# Inputs: image, comma-separated labels, DINO box-confidence threshold.
# Output: the image annotated with one SAM mask per detected label.
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(label="Image Input"), gr.Textbox(label = "Candidate Labels"), gr.Slider(0, 1, value=0.05, label="Confidence Threshold for GroundingDINO")],
    outputs="annotatedimage",
    title="GroundingDINO 🤝 SAM for Zero-shot Segmentation",
    description=description,
    examples=[
        ["./cats.png", "cat, fishnet", 0.16],["./bee.jpg", "bee, flower", 0.16]
    ],
)
demo.launch(debug=True)

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://ed4e9771b82e5de76d.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/gradio/queueing.py", line 527, in process_events
    response = await route_utils.call_process_api(
  File "/usr/local/lib/python3.10/dist-packages/gradio/route_utils.py", line 261, in call_process_api
    output = await app.get_blocks().process_api(
  File "/usr/local/lib/python3.10/dist-packages/gradio/blocks.py", line 1788, in process_api
    result = await self.call_function(
  File "/usr/local/lib/python3.10/dist-packages/gradio/blocks.py", line 1340, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/usr/local/lib/python3.10/dist-packages/anyio/to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
    return await future
  File "/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py", line 807, in run
    r

Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://ed4e9771b82e5de76d.gradio.live


