In [None]:
%matplotlib ipympl
import os

import common.torch as torch
import requests
import cv2
import supervision as sv
from supervision.draw.color import ColorPalette
import numpy as np
from PIL import Image

import matplotlib.pyplot as plt

import experimental.overhead_matching.grounding_sam as gs
import importlib
# Re-execute the grounding_sam module in the running kernel; presumably here so
# that edits to that module are picked up without restarting the notebook.
importlib.reload(gs)



In [None]:
# Instantiate the detector/segmenter wrapper from the local grounding_sam module.
model = gs.GroundingSam()

# Fetch a sample COCO image. A timeout keeps the cell from hanging forever on a
# stalled connection, raise_for_status() surfaces HTTP errors directly instead
# of as an opaque PIL decoding failure, and the `with` block releases the
# streamed connection once the pixels have been read.
image_url = "https://s3.us-east-1.amazonaws.com/images.cocodataset.org/train2017/000000146439.jpg"
REQUEST_TIMEOUT_S = 30
with requests.get(image_url, stream=True, timeout=REQUEST_TIMEOUT_S) as response:
    response.raise_for_status()
    # Have urllib3 undo any Content-Encoding (gzip/deflate) on the raw stream.
    response.raw.decode_content = True
    image = Image.open(response.raw)
    # PIL decodes lazily; force the decode while the response is still open.
    image.load()

# RGB uint8 array handed to the model and to matplotlib.
img = np.array(image.convert('RGB'))

plt.figure()
plt.imshow(img)

# NOTE(review): the original comment here said text queries must be lowercased
# and end with a dot. The queries below are lowercase but carry no trailing
# dot, so presumably `detect_queries` formats the prompt itself -- confirm
# against grounding_sam before relying on it.
queries = ['cat', 'keyboard', 'piano', 'speaker']

# `results` is indexed below with the keys 'dino_results' and 'sam_results'.
results = model.detect_queries(image=img, queries=queries)


In [None]:
# Dump the raw detector output, then display the first box as a plain list
# (the bare last expression is rendered as the cell output).
dino_results = results['dino_results']
print(dino_results)

first_box = dino_results['boxes'][0]
first_box.tolist()


In [None]:

# --- Captions -------------------------------------------------------------
confidences = results['dino_results']["scores"].tolist()
class_names = results['dino_results']["labels"]
# One id per detection (not per distinct class name). np.arange keeps an
# integer dtype even when there are zero detections, unlike
# np.array(list(range(0))) which yields an empty float array.
class_ids = np.arange(len(class_names))

# "<label> <score>" caption for every detection, in detection order.
labels = [
    f"{class_name} {confidence:.2f}"
    for class_name, confidence
    in zip(class_names, confidences)
]

# --- Build the supervision Detections container ----------------------------
# NOTE(review): sv.Detections expects numpy arrays; `masks` is clearly numpy
# (it has .astype), `boxes` is assumed to be numpy as well -- confirm.
detections = sv.Detections(
    xyxy=results['dino_results']['boxes'],  # (n, 4)
    mask=results['sam_results']['masks'].astype(bool),  # (n, h, w)
    class_id=class_ids
)

print(detections)

# Note: to use the default colour map, pass color=ColorPalette.DEFAULT to the
# annotators below.

# NOTE: despite its name, `bgr_img` holds RGB data (image.convert('RGB')),
# which is what plt.imshow expects. The name is kept in case later cells use it.
bgr_img = np.array(image.convert('RGB'))
OUTPUT_DIR = '/tmp'

# --- Figure 1: bounding boxes only -----------------------------------------
box_annotator = sv.BoxAnnotator()
boxed_frame = box_annotator.annotate(scene=bgr_img.copy(),
                 detections=detections)

plt.figure()
plt.imshow(boxed_frame)

# --- Figure 2: bounding boxes + captions -----------------------------------
# Draw the captions on a copy of the boxed frame. Previously they were drawn
# on a fresh copy of the image, so this figure showed captions with no boxes
# to attach them to. The copy leaves the array shown in figure 1 untouched.
label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(scene=boxed_frame.copy(),
                           detections=detections, labels=labels)

plt.figure()
plt.imshow(annotated_frame)

# To save the result with OpenCV, convert RGB -> BGR first, otherwise the
# colours in the written file are swapped:
#   cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"),
#               cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR))


In [None]:
# Show every detection's segmentation mask in its own subplot and print its box.
# The grid grows with the number of detections: a fixed 2x2 grid makes
# plt.subplot raise ValueError on the fifth detection. For four or fewer
# detections the layout is still 2x2.
n_cols = 2
n_rows = max(2, -(-len(detections) // n_cols))  # ceiling division

plt.figure()
for i in range(len(detections)):
    # `d` is intentionally left as a plain notebook-level name: the next cell
    # inspects the last detection through it.
    d = detections[i]
    plt.subplot(n_rows, n_cols, i + 1)
    # squeeze() is applied before display to drop any length-1 axes of the mask.
    plt.imshow(d.mask.squeeze())
    plt.colorbar()
    print(d.xyxy)

In [None]:
# `d` is the loop variable left over from the mask-plotting cell above, i.e. the
# last detection; this cell fails if that loop never ran (no detections).
# Displays the indices of the non-zero entries of its first mask.
d.mask[0].nonzero()