In [1]:
# import whisper
import whisper
# import groundingdino
from groundingdino.util.inference import load_model, load_image, predict, annotate
import cv2
# bounding box imports
import numpy as np
import torch
from torchvision.ops import box_convert
import matplotlib.pyplot as plt
# sam imports
from segment_anything import SamPredictor, sam_model_registry

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
### WHISPER
def speechToText(audio_path):
    """Transcribe an audio file to text with Whisper's ``base.en`` model.

    Args:
        audio_path: Path of the audio file handed to ``model.transcribe``.

    Returns:
        The ``"text"`` entry of the transcription result.
    """
    # Note: the model is loaded again on every call.
    whisper_model = whisper.load_model("base.en")
    print("Whisper loaded.")

    transcription = whisper_model.transcribe(audio_path)
    text = transcription["text"]
    print("Prompt: ", text)
    return text

In [3]:
### GROUNDINGDINO
def getBoundingBox(prompt, image_path, output_path, tasks,
                   box_threshold=0.35, text_threshold=0.25, device="cpu"):
    """Locate the object described by ``prompt`` in an image with GroundingDINO.

    Args:
        prompt: Text caption describing the object to find.
        image_path: Path of the image to search.
        output_path: Where the annotated image is written when ``"annotate"``
            is in ``tasks``.
        tasks: Collection of task names; only ``"annotate"`` is checked for.
        box_threshold: Forwarded to ``predict`` as ``box_threshold``.
        text_threshold: Forwarded to ``predict`` as ``text_threshold``.
        device: Device string forwarded to ``predict``.

    Returns:
        A NumPy array ``[x0, y0, x1, y1]`` in pixel coordinates for the
        detection with the highest score.

    Raises:
        IndexError: If no detection passes the thresholds.
    """
    # Note: the model is loaded again on every call.
    model = load_model(
        "../GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
        "../GroundingDINO/weights/groundingdino_swint_ogc.pth",
    )
    print("GroundingDINO loaded.")

    image_source, image = load_image(image_path)
    print("Image loaded.")

    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=device
    )

    if "annotate" in tasks:
        annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
        cv2.imwrite(output_path, annotated_frame)
        print("Box drawn.")

    # Fail with a readable message instead of an opaque out-of-bounds error.
    if len(boxes) == 0:
        raise IndexError(
            f"GroundingDINO found no box for prompt {prompt!r} in {image_path!r} "
            f"(box_threshold={box_threshold}, text_threshold={text_threshold})"
        )

    all_coords = getCoords(image_source=image_source, boxes=boxes)
    # predict() does not sort its detections by score, so the best box has to
    # be picked explicitly rather than assumed to be the first one.
    best_index = int(torch.argmax(logits))
    coords = all_coords[best_index]
    print("Coordinates: ", coords)
    return coords

def getCoords(image_source: np.ndarray, boxes: torch.Tensor):
    """Scale centre-format boxes by the image size and convert them to corners.

    Args:
        image_source: Image array whose shape unpacks as
            ``(height, width, channels)``.
        boxes: Boxes in ``cxcywh`` layout. Each row is multiplied by
            ``[width, height, width, height]``, so the values are presumably
            normalised to the image size -- confirm against the detector.

    Returns:
        NumPy array holding the scaled boxes in ``xyxy`` layout.
    """
    height, width, _channels = image_source.shape
    pixel_scale = torch.Tensor([width, height, width, height])
    scaled_boxes = boxes * pixel_scale
    corner_boxes = box_convert(boxes=scaled_boxes, in_fmt="cxcywh", out_fmt="xyxy")
    return corner_boxes.numpy()

In [4]:
### SAM
def show_box(box, ax):
    """Draw ``box`` on ``ax`` as a green rectangle with a transparent fill.

    Args:
        box: Indexable holding ``x0, y0, x1, y1`` at positions 0 to 3.
        ax: Axes-like object exposing ``add_patch``.
    """
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

def show_mask(mask, ax, random_color=False):
    """Show ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are ``(height, width)`` and
            which can be reshaped to ``(height, width, 1)``.
        ax: Axes-like object exposing ``imshow``.
        random_color: When truthy, draw the RGB part of the colour at random;
            otherwise use a fixed blue. The alpha value is 0.6 in both cases.
    """
    rgb = np.random.random(3) if random_color else np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast the (h, w, 1) mask against the (1, 1, 4) colour.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def loadSAM(sam_checkpoint="../SAM/checkpoints/sam_vit_h_4b8939.pth",
            model_type="vit_h", device=None):
    """Build a Segment Anything model from a checkpoint file.

    Args:
        sam_checkpoint: Path of the checkpoint passed to the model builder.
        model_type: Key used to look up the builder in ``sam_model_registry``.
        device: Optional device to move the model to (for example ``"cuda"``).
            ``None`` (the default) performs no move, which matches the
            behaviour of calling this function without arguments before.

    Returns:
        The constructed SAM model.
    """
    # Clear PyTorch's CUDA cache before building the model when a GPU exists.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    if device is not None:
        sam.to(device=device)
    print("SAM loaded.")
    return sam

def getObjectMask(sam, image_path, coords):
    """Segment the object inside a bounding box with SAM and plot the result.

    Args:
        sam: SAM model handed to ``SamPredictor``.
        image_path: Path of the image to segment.
        coords: Box as ``[x0, y0, x1, y1]``; any sequence that
            ``np.asarray`` accepts (list, tuple or array).

    Returns:
        The first mask returned by the predictor (``masks[0]``).

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read ``image_path``.
    """
    image_bgr = cv2.imread(image_path)
    # cv2.imread reports failure by returning None instead of raising.
    if image_bgr is None:
        raise FileNotFoundError(f"Image is missing or unreadable: {image_path}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    predictor = SamPredictor(sam)
    predictor.set_image(image_rgb)
    print("Predictor set.")

    # The [None, :] indexing below needs an ndarray, so coerce lists/tuples.
    input_box = np.asarray(coords)

    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=False,
    )
    mask = masks[0]

    plt.figure(figsize=(8, 8))
    plt.imshow(image_rgb)
    show_mask(mask, plt.gca())
    show_box(input_box, plt.gca())
    plt.axis('off')
    plt.show()

    # Hand the mask back so callers can use it, not only look at the plot.
    return mask

In [5]:
### MAIN PROGRAM
# Interactive input is currently disabled; the commented lines below show how
# the file name and prompt were read from the user or from an audio file.
# filename = input("Filename: ")
# from_audio = input("Use audio file? (Y/N)")

# if from_audio == "Y":
#     audio_path = "images/" + filename + ".mp3"
#     prompt = speechToText(audio_path)
# else:
#     prompt = input("Text prompt: ")

# Hard-coded inputs used instead of the interactive prompts above.
filename = "8259"
prompt = "person on the left"

# get_mask = input("Get mask? (Y/N)")

# Input image and annotated-output paths are derived from the file name.
image_path = f"{filename}.png"
output_path = f"{filename}-annotated.jpg"

# An empty task list means the "annotate" step is skipped.
coords = getBoundingBox(
    prompt=prompt,
    image_path=image_path,
    output_path=output_path,
    tasks=[],
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


final text_encoder_type: bert-base-uncased


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


GroundingDINO loaded.
Image loaded.




Coordinates:  [1113.5714   454.08368 1153.2848   548.22864]


In [6]:
# Read the image with OpenCV and convert it from BGR to RGB in one step.
image_cv2 = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)

# Build the SAM model once so the following cells can reuse it.
sam = loadSAM()

SAM loaded.


In [7]:
# Wrap the SAM model in a predictor and register the image with it; the
# `predictor` name is reused by the prediction cell that follows.
predictor = SamPredictor(sam)
predictor.set_image(image_cv2)
print("Predictor set.")

Predictor set.


In [None]:
# The detected box is used directly as the SAM box prompt.
# input_box = np.array(coords)
input_box = coords

# The box is passed with an extra leading axis; no point prompts are given.
masks, _, _ = predictor.predict(
    box=input_box[None, :],
    point_coords=None,
    point_labels=None,
    multimask_output=False,
)

# Draw the image, then the mask overlay and the box outline on the same axes.
plt.figure(figsize=(8, 8))
ax = plt.gca()
ax.imshow(image_cv2)
show_mask(masks[0], ax)
show_box(input_box, ax)
ax.axis('off')
plt.show()